diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index 39e1e6e27..d83406db2 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -283,6 +283,7 @@ def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [Pure]> { let hasVerifier = 1; let hasFolder = 1; + let hasCanonicalizer = 1; } def FHE_MulEintOp : FHE_Op<"mul_eint", [Pure]> { diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index 45c131946..8a0b2e718 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -373,6 +373,8 @@ def FHELinalg_MulEintIntOp : FHELinalg_Op<"mul_eint_int", [TensorBroadcastingRul let results = (outs Type.predicate, HasStaticShapePred]>>); let hasFolder = 1; + + let hasCanonicalizer = 1; } def FHELinalg_MulEintOp : FHELinalg_Op<"mul_eint", [TensorBroadcastingRules, TensorBinaryEint]> { @@ -748,6 +750,13 @@ def FHELinalg_MatMulEintIntOp : FHELinalg_Op<"matmul_eint_int", [TensorBinaryEin let results = (outs Type.predicate, HasStaticShapePred]>>); let hasVerifier = 1; + + let hasCanonicalizer = 1; + + let extraClassDeclaration = [{ + mlir::Value getClearMatrix() { return getRhs(); } + mlir::Value getEncryptedMatrix () { return getLhs(); } + }]; } def FHELinalg_MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> { @@ -886,6 +895,13 @@ def FHELinalg_MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [TensorBinaryInt let results = (outs Type.predicate, HasStaticShapePred]>>); let hasVerifier = 1; + + let hasCanonicalizer = 1; + + let extraClassDeclaration = [{ + mlir::Value getClearMatrix() { return getLhs(); } + mlir::Value getEncryptedMatrix () { return getRhs(); } + }]; } diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index 72ad80226..e72a64d86 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -3,6 +3,8 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/Region.h" #include "mlir/IR/TypeUtilities.h" @@ -345,6 +347,53 @@ OpFoldResult MulEintIntOp::fold(FoldAdaptor operands) { return nullptr; } +void MulEintIntOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + + // Replace multiplication by clear zero cst to a trivial encrypted zero tensor + class ZeroCstOpPattern : public mlir::OpRewritePattern { + public: + ZeroCstOpPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern(context, 0) {} + + mlir::LogicalResult + matchAndRewrite(MulEintIntOp op, + mlir::PatternRewriter &rewriter) const override { + auto cstOp = op.getB().getDefiningOp(); + if (cstOp == nullptr) + return mlir::failure(); + auto val = cstOp->getAttrOfType("value"); + if (val.getInt() != 0) { + return mlir::failure(); + } + rewriter.replaceOpWithNewOp(op, + op.getResult().getType()); + return mlir::success(); + } + }; + + // Replace multiplication by encrypted zero cst to a trivial encrypted zero + // tensor + class ZeroEncOpPattern : public mlir::OpRewritePattern { + public: + ZeroEncOpPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern(context, 0) {} + + mlir::LogicalResult + matchAndRewrite(MulEintIntOp op, + mlir::PatternRewriter &rewriter) const override { + auto cstOp = op.getA().getDefiningOp(); + if (cstOp == nullptr) + return mlir::failure(); + rewriter.replaceAllUsesWith(op, cstOp); + rewriter.eraseOp(op); + return mlir::success(); + } + }; + patterns.add(context); + patterns.add(context); +} + } // namespace FHE } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index 17939d1b6..535303e47 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -1470,8 +1470,58 @@ OpFoldResult MulEintIntOp::fold(FoldAdaptor operands) { return getOperand(0); } +void MulEintIntOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + + // Replace multiplication by clear zero cst to a trivial encrypted zero tensor + class ZeroCstOpPattern : public mlir::OpRewritePattern { + public: + ZeroCstOpPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern(context, 0) {} + + mlir::LogicalResult + matchAndRewrite(MulEintIntOp op, + mlir::PatternRewriter &rewriter) const override { + auto cstOp = op.getRhs().getDefiningOp(); + if (cstOp == nullptr) + return mlir::failure(); + auto vals = cstOp->getAttrOfType("value"); + for (auto it = vals.begin(); it != vals.end(); it++) { + if (*it != 0) { + return mlir::failure(); + } + } + rewriter.replaceOpWithNewOp(op, + op.getResult().getType()); + return mlir::success(); + } + }; + + // Replace multiplication by encrypted zero cst to a trivial encrypted zero + // tensor + class ZeroEncOpPattern : public mlir::OpRewritePattern { + public: + ZeroEncOpPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern(context, 0) {} + + mlir::LogicalResult + matchAndRewrite(MulEintIntOp op, + mlir::PatternRewriter &rewriter) const override { + auto cstOp = op.getLhs().getDefiningOp(); + if (cstOp == nullptr) + return mlir::failure(); + rewriter.replaceAllUsesWith(op, cstOp); + rewriter.eraseOp(op); + return mlir::success(); + } + }; + patterns.add(context); + patterns.add(context); +} + /// Avoid multiplication with constant tensor of 1s OpFoldResult RoundOp::fold(FoldAdaptor operands) { + auto input = this->getInput(); auto inputType = this->getInput().getType().dyn_cast_or_null(); @@ -1489,6 +1539,69 @@ OpFoldResult RoundOp::fold(FoldAdaptor operands) { return nullptr; } +template +void getMatMulCanonicalizationPatterns(mlir::RewritePatternSet &patterns, + mlir::MLIRContext *context) { + + // Replace multiplication by clear zero cst to a trivial encrypted zero tensor + class ZeroCstOpPattern : public mlir::OpRewritePattern { + public: + ZeroCstOpPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern(context, 0) {} + + mlir::LogicalResult + matchAndRewrite(MatMulOp op, + mlir::PatternRewriter &rewriter) const override { + auto cstOp = + op.getClearMatrix().template getDefiningOp(); + if (cstOp == nullptr) + return mlir::failure(); + auto vals = + cstOp->template getAttrOfType("value"); + for (auto it = vals.begin(); it != vals.end(); it++) { + if (*it != 0) { + return mlir::failure(); + } + } + rewriter.replaceOpWithNewOp(op, + op.getResult().getType()); + return mlir::success(); + } + }; + + // Replace multiplication by encrypted zero cst to a trivial encrypted zero + // tensor + class ZeroEncOpPattern : public mlir::OpRewritePattern { + public: + ZeroEncOpPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern(context, 0) {} + + mlir::LogicalResult + matchAndRewrite(MatMulOp op, + mlir::PatternRewriter &rewriter) const override { + auto cstOp = + op.getEncryptedMatrix().template getDefiningOp(); + if (cstOp == nullptr) + return mlir::failure(); + rewriter.replaceOpWithNewOp(op, + op.getResult().getType()); + return mlir::success(); + } + }; + patterns.add(context); + patterns.add(context); +} + +void MatMulIntEintOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + getMatMulCanonicalizationPatterns(patterns, context); +} + +void MatMulEintIntOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + getMatMulCanonicalizationPatterns(patterns, context); +} + } // namespace FHELinalg } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir index 0c84d2e83..199f48b8e 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir @@ -379,8 +379,7 @@ func.func @matmul_eint_int_cst_p_2_n_1(%arg0: tensor<3x2x!FHE.eint<2>>) -> tenso // ----- -func.func @matmul_eint_int_cst() -> tensor<4x3x!FHE.eint<7>> { - %0 = "FHE.zero_tensor"() : () -> tensor<4x3x!FHE.eint<7>> +func.func @matmul_eint_int_cst(%0: tensor<4x3x!FHE.eint<7>>) -> tensor<4x3x!FHE.eint<7>> { // =============================== @@ -663,8 +662,7 @@ func.func @matmul_int_eint_cst_p_2_n_1(%arg0: tensor<2x3x!FHE.eint<2>>) -> tenso // ----- -func.func @matmul_int_eint_cst() -> tensor<3x2x!FHE.eint<7>> { - %0 = "FHE.zero_tensor"() : () -> tensor<3x2x!FHE.eint<7>> +func.func @matmul_int_eint_cst(%0: tensor<3x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> { // =============================== diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/folding.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/folding.mlir index 228942eb8..c9643f16a 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/folding.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/folding.mlir @@ -27,6 +27,26 @@ func.func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { return %1: !FHE.eint<2> } +// CHECK-LABEL: func.func @mul_eint_int_zero(%arg0: !FHE.eint<2>) -> !FHE.eint<2> +func.func @mul_eint_int_zero(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { + // CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() + // CHECK-NEXT: return %[[v0]] : !FHE.eint<2> + + %0 = arith.constant 0 : i3 + %1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + +// CHECK-LABEL: func.func @mul_eint_zero_int(%arg0: i3) -> !FHE.eint<2> +func.func @mul_eint_zero_int(%arg0: i3) -> !FHE.eint<2> { + // CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() + // CHECK-NEXT: return %[[v0]] : !FHE.eint<2> + + %0 = "FHE.zero"() : () -> !FHE.eint<2> + %1 = "FHE.mul_eint_int"(%0, %arg0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + // CHECK-LABEL: func.func @round(%arg0: !FHE.eint<5>) -> !FHE.eint<5> func.func @round(%arg0: !FHE.eint<5>) -> !FHE.eint<5> { // CHECK-NEXT: return %arg0 : !FHE.eint<5> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir index 092a2ccef..d5d62e800 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir @@ -88,3 +88,57 @@ func.func @round(%arg0: tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> { %1 = "FHELinalg.round"(%arg0) : (tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> return %1: tensor<4x!FHE.eint<5>> } + +// CHECK: func.func @mul_by_clear_zero(%[[a0:.*]]: tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> +func.func @mul_by_clear_zero(%arg0: tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> { + // CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() + // CHECK-NEXT: return %[[v0]] : tensor<4x!FHE.eint<5>> + %cst_0 = arith.constant dense<0> : tensor<4xi5> + %1 = "FHELinalg.mul_eint_int"(%arg0, %cst_0) : (tensor<4x!FHE.eint<5>>, tensor<4xi5>) -> tensor<4x!FHE.eint<5>> + return %1: tensor<4x!FHE.eint<5>> +} + +// CHECK: func.func @mul_by_encrypted_zero(%[[a0:.*]]: tensor<4xi5>) -> tensor<4x!FHE.eint<5>> +func.func @mul_by_encrypted_zero(%arg0: tensor<4xi5>) -> tensor<4x!FHE.eint<5>> { + // CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() + // CHECK-NEXT: return %[[v0]] : tensor<4x!FHE.eint<5>> + %cst0 = "FHE.zero_tensor"() : () -> tensor<4x!FHE.eint<5>> + %1 = "FHELinalg.mul_eint_int"(%cst0, %arg0) : (tensor<4x!FHE.eint<5>>, tensor<4xi5>) -> tensor<4x!FHE.eint<5>> + return %1: tensor<4x!FHE.eint<5>> +} + +// CHECK: func.func @matmul_eint_int_clear_zero(%[[a0:.*]]: tensor<4x3x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() +// CHECK-NEXT: return %[[v0]] +func.func @matmul_eint_int_clear_zero(%x: tensor<4x3x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>> { + %y = arith.constant dense<0> : tensor<3x2xi3> + %0 = "FHELinalg.matmul_eint_int"(%x, %y): (tensor<4x3x!FHE.eint<2>>, tensor<3x2xi3>) -> tensor<4x2x!FHE.eint<2>> + return %0 : tensor<4x2x!FHE.eint<2>> +} + +// CHECK: func.func @matmul_eint_int_encrypted_zero(%[[a0:.*]]: tensor<3x2xi3>) -> tensor<4x2x!FHE.eint<2>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() +// CHECK-NEXT: return %[[v0]] +func.func @matmul_eint_int_encrypted_zero(%y: tensor<3x2xi3>) -> tensor<4x2x!FHE.eint<2>> { + %x = "FHE.zero_tensor"() : () -> tensor<4x3x!FHE.eint<2>> + %0 = "FHELinalg.matmul_eint_int"(%x, %y): (tensor<4x3x!FHE.eint<2>>, tensor<3x2xi3>) -> tensor<4x2x!FHE.eint<2>> + return %0 : tensor<4x2x!FHE.eint<2>> +} + +// CHECK: func.func @matmul_int_eint_clear_zero(%[[a0:.*]]: tensor<3x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() +// CHECK-NEXT: return %[[v0]] +func.func @matmul_int_eint_clear_zero(%y: tensor<3x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>> { + %x = arith.constant dense<0> : tensor<4x3xi3> + %0 = "FHELinalg.matmul_int_eint"(%x, %y): (tensor<4x3xi3>, tensor<3x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>> + return %0 : tensor<4x2x!FHE.eint<2>> +} + +// CHECK: func.func @matmul_int_eint_encrypted_zero(%[[a0:.*]]: tensor<4x3xi3>) -> tensor<4x2x!FHE.eint<2>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() +// CHECK-NEXT: return %[[v0]] +func.func @matmul_int_eint_encrypted_zero(%x: tensor<4x3xi3>) -> tensor<4x2x!FHE.eint<2>> { + %y = "FHE.zero_tensor"() : () -> tensor<3x2x!FHE.eint<2>> + %0 = "FHELinalg.matmul_int_eint"(%x, %y): (tensor<4x3xi3>, tensor<3x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>> + return %0 : tensor<4x2x!FHE.eint<2>> +}