mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix(compiler): fold mul and matmul by zero to zero
That will close https://github.com/zama-ai/concrete-internal/issues/297 also for dag-multi optimization
This commit is contained in:
committed by
Quentin Bourgerie
parent
3561b51329
commit
f487432207
@@ -283,6 +283,7 @@ def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [Pure]> {
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def FHE_MulEintOp : FHE_Op<"mul_eint", [Pure]> {
|
||||
|
||||
@@ -373,6 +373,8 @@ def FHELinalg_MulEintIntOp : FHELinalg_Op<"mul_eint_int", [TensorBroadcastingRul
|
||||
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def FHELinalg_MulEintOp : FHELinalg_Op<"mul_eint", [TensorBroadcastingRules, TensorBinaryEint]> {
|
||||
@@ -748,6 +750,13 @@ def FHELinalg_MatMulEintIntOp : FHELinalg_Op<"matmul_eint_int", [TensorBinaryEin
|
||||
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::Value getClearMatrix() { return getRhs(); }
|
||||
mlir::Value getEncryptedMatrix () { return getLhs(); }
|
||||
}];
|
||||
}
|
||||
|
||||
def FHELinalg_MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> {
|
||||
@@ -886,6 +895,13 @@ def FHELinalg_MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [TensorBinaryInt
|
||||
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::Value getClearMatrix() { return getLhs(); }
|
||||
mlir::Value getEncryptedMatrix () { return getRhs(); }
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Region.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
@@ -345,6 +347,53 @@ OpFoldResult MulEintIntOp::fold(FoldAdaptor operands) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void MulEintIntOp::getCanonicalizationPatterns(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
|
||||
|
||||
// Replace multiplication by clear zero cst to a trivial encrypted zero tensor
|
||||
class ZeroCstOpPattern : public mlir::OpRewritePattern<MulEintIntOp> {
|
||||
public:
|
||||
ZeroCstOpPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<MulEintIntOp>(context, 0) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(MulEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cstOp = op.getB().getDefiningOp<arith::ConstantOp>();
|
||||
if (cstOp == nullptr)
|
||||
return mlir::failure();
|
||||
auto val = cstOp->getAttrOfType<mlir::IntegerAttr>("value");
|
||||
if (val.getInt() != 0) {
|
||||
return mlir::failure();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<FHE::ZeroEintOp>(op,
|
||||
op.getResult().getType());
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
// Replace multiplication by encrypted zero cst to a trivial encrypted zero
|
||||
// tensor
|
||||
class ZeroEncOpPattern : public mlir::OpRewritePattern<MulEintIntOp> {
|
||||
public:
|
||||
ZeroEncOpPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<MulEintIntOp>(context, 0) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(MulEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cstOp = op.getA().getDefiningOp<FHE::ZeroEintOp>();
|
||||
if (cstOp == nullptr)
|
||||
return mlir::failure();
|
||||
rewriter.replaceAllUsesWith(op, cstOp);
|
||||
rewriter.eraseOp(op);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
patterns.add<ZeroCstOpPattern>(context);
|
||||
patterns.add<ZeroEncOpPattern>(context);
|
||||
}
|
||||
|
||||
} // namespace FHE
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -1470,8 +1470,58 @@ OpFoldResult MulEintIntOp::fold(FoldAdaptor operands) {
|
||||
return getOperand(0);
|
||||
}
|
||||
|
||||
void MulEintIntOp::getCanonicalizationPatterns(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
|
||||
|
||||
// Replace multiplication by clear zero cst to a trivial encrypted zero tensor
|
||||
class ZeroCstOpPattern : public mlir::OpRewritePattern<MulEintIntOp> {
|
||||
public:
|
||||
ZeroCstOpPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<MulEintIntOp>(context, 0) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(MulEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cstOp = op.getRhs().getDefiningOp<arith::ConstantOp>();
|
||||
if (cstOp == nullptr)
|
||||
return mlir::failure();
|
||||
auto vals = cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value");
|
||||
for (auto it = vals.begin(); it != vals.end(); it++) {
|
||||
if (*it != 0) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<FHE::ZeroTensorOp>(op,
|
||||
op.getResult().getType());
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
// Replace multiplication by encrypted zero cst to a trivial encrypted zero
|
||||
// tensor
|
||||
class ZeroEncOpPattern : public mlir::OpRewritePattern<MulEintIntOp> {
|
||||
public:
|
||||
ZeroEncOpPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<MulEintIntOp>(context, 0) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(MulEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cstOp = op.getLhs().getDefiningOp<FHE::ZeroTensorOp>();
|
||||
if (cstOp == nullptr)
|
||||
return mlir::failure();
|
||||
rewriter.replaceAllUsesWith(op, cstOp);
|
||||
rewriter.eraseOp(op);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
patterns.add<ZeroCstOpPattern>(context);
|
||||
patterns.add<ZeroEncOpPattern>(context);
|
||||
}
|
||||
|
||||
/// Avoid multiplication with constant tensor of 1s
|
||||
OpFoldResult RoundOp::fold(FoldAdaptor operands) {
|
||||
|
||||
auto input = this->getInput();
|
||||
auto inputType =
|
||||
this->getInput().getType().dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
@@ -1489,6 +1539,69 @@ OpFoldResult RoundOp::fold(FoldAdaptor operands) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <typename MatMulOp>
|
||||
void getMatMulCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
|
||||
mlir::MLIRContext *context) {
|
||||
|
||||
// Replace multiplication by clear zero cst to a trivial encrypted zero tensor
|
||||
class ZeroCstOpPattern : public mlir::OpRewritePattern<MatMulOp> {
|
||||
public:
|
||||
ZeroCstOpPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<MatMulOp>(context, 0) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(MatMulOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cstOp =
|
||||
op.getClearMatrix().template getDefiningOp<arith::ConstantOp>();
|
||||
if (cstOp == nullptr)
|
||||
return mlir::failure();
|
||||
auto vals =
|
||||
cstOp->template getAttrOfType<mlir::DenseIntElementsAttr>("value");
|
||||
for (auto it = vals.begin(); it != vals.end(); it++) {
|
||||
if (*it != 0) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<FHE::ZeroTensorOp>(op,
|
||||
op.getResult().getType());
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
// Replace multiplication by encrypted zero cst to a trivial encrypted zero
|
||||
// tensor
|
||||
class ZeroEncOpPattern : public mlir::OpRewritePattern<MatMulOp> {
|
||||
public:
|
||||
ZeroEncOpPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<MatMulOp>(context, 0) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(MatMulOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cstOp =
|
||||
op.getEncryptedMatrix().template getDefiningOp<FHE::ZeroTensorOp>();
|
||||
if (cstOp == nullptr)
|
||||
return mlir::failure();
|
||||
rewriter.replaceOpWithNewOp<FHE::ZeroTensorOp>(op,
|
||||
op.getResult().getType());
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
patterns.add<ZeroCstOpPattern>(context);
|
||||
patterns.add<ZeroEncOpPattern>(context);
|
||||
}
|
||||
|
||||
void MatMulIntEintOp::getCanonicalizationPatterns(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
|
||||
getMatMulCanonicalizationPatterns<MatMulIntEintOp>(patterns, context);
|
||||
}
|
||||
|
||||
void MatMulEintIntOp::getCanonicalizationPatterns(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
|
||||
getMatMulCanonicalizationPatterns<MatMulEintIntOp>(patterns, context);
|
||||
}
|
||||
|
||||
} // namespace FHELinalg
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -379,8 +379,7 @@ func.func @matmul_eint_int_cst_p_2_n_1(%arg0: tensor<3x2x!FHE.eint<2>>) -> tenso
|
||||
|
||||
// -----
|
||||
|
||||
func.func @matmul_eint_int_cst() -> tensor<4x3x!FHE.eint<7>> {
|
||||
%0 = "FHE.zero_tensor"() : () -> tensor<4x3x!FHE.eint<7>>
|
||||
func.func @matmul_eint_int_cst(%0: tensor<4x3x!FHE.eint<7>>) -> tensor<4x3x!FHE.eint<7>> {
|
||||
|
||||
// ===============================
|
||||
|
||||
@@ -663,8 +662,7 @@ func.func @matmul_int_eint_cst_p_2_n_1(%arg0: tensor<2x3x!FHE.eint<2>>) -> tenso
|
||||
|
||||
// -----
|
||||
|
||||
func.func @matmul_int_eint_cst() -> tensor<3x2x!FHE.eint<7>> {
|
||||
%0 = "FHE.zero_tensor"() : () -> tensor<3x2x!FHE.eint<7>>
|
||||
func.func @matmul_int_eint_cst(%0: tensor<3x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> {
|
||||
|
||||
// ===============================
|
||||
|
||||
|
||||
@@ -27,6 +27,26 @@ func.func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @mul_eint_int_zero(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
func.func @mul_eint_int_zero(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"()
|
||||
// CHECK-NEXT: return %[[v0]] : !FHE.eint<2>
|
||||
|
||||
%0 = arith.constant 0 : i3
|
||||
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>)
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @mul_eint_zero_int(%arg0: i3) -> !FHE.eint<2>
|
||||
func.func @mul_eint_zero_int(%arg0: i3) -> !FHE.eint<2> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"()
|
||||
// CHECK-NEXT: return %[[v0]] : !FHE.eint<2>
|
||||
|
||||
%0 = "FHE.zero"() : () -> !FHE.eint<2>
|
||||
%1 = "FHE.mul_eint_int"(%0, %arg0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>)
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @round(%arg0: !FHE.eint<5>) -> !FHE.eint<5>
|
||||
func.func @round(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
// CHECK-NEXT: return %arg0 : !FHE.eint<5>
|
||||
|
||||
@@ -88,3 +88,57 @@ func.func @round(%arg0: tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> {
|
||||
%1 = "FHELinalg.round"(%arg0) : (tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>>
|
||||
return %1: tensor<4x!FHE.eint<5>>
|
||||
}
|
||||
|
||||
// CHECK: func.func @mul_by_clear_zero(%[[a0:.*]]: tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>>
|
||||
func.func @mul_by_clear_zero(%arg0: tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"()
|
||||
// CHECK-NEXT: return %[[v0]] : tensor<4x!FHE.eint<5>>
|
||||
%cst_0 = arith.constant dense<0> : tensor<4xi5>
|
||||
%1 = "FHELinalg.mul_eint_int"(%arg0, %cst_0) : (tensor<4x!FHE.eint<5>>, tensor<4xi5>) -> tensor<4x!FHE.eint<5>>
|
||||
return %1: tensor<4x!FHE.eint<5>>
|
||||
}
|
||||
|
||||
// CHECK: func.func @mul_by_encrypted_zero(%[[a0:.*]]: tensor<4xi5>) -> tensor<4x!FHE.eint<5>>
|
||||
func.func @mul_by_encrypted_zero(%arg0: tensor<4xi5>) -> tensor<4x!FHE.eint<5>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"()
|
||||
// CHECK-NEXT: return %[[v0]] : tensor<4x!FHE.eint<5>>
|
||||
%cst0 = "FHE.zero_tensor"() : () -> tensor<4x!FHE.eint<5>>
|
||||
%1 = "FHELinalg.mul_eint_int"(%cst0, %arg0) : (tensor<4x!FHE.eint<5>>, tensor<4xi5>) -> tensor<4x!FHE.eint<5>>
|
||||
return %1: tensor<4x!FHE.eint<5>>
|
||||
}
|
||||
|
||||
// CHECK: func.func @matmul_eint_int_clear_zero(%[[a0:.*]]: tensor<4x3x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"()
|
||||
// CHECK-NEXT: return %[[v0]]
|
||||
func.func @matmul_eint_int_clear_zero(%x: tensor<4x3x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>> {
|
||||
%y = arith.constant dense<0> : tensor<3x2xi3>
|
||||
%0 = "FHELinalg.matmul_eint_int"(%x, %y): (tensor<4x3x!FHE.eint<2>>, tensor<3x2xi3>) -> tensor<4x2x!FHE.eint<2>>
|
||||
return %0 : tensor<4x2x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// CHECK: func.func @matmul_eint_int_encrypted_zero(%[[a0:.*]]: tensor<3x2xi3>) -> tensor<4x2x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"()
|
||||
// CHECK-NEXT: return %[[v0]]
|
||||
func.func @matmul_eint_int_encrypted_zero(%y: tensor<3x2xi3>) -> tensor<4x2x!FHE.eint<2>> {
|
||||
%x = "FHE.zero_tensor"() : () -> tensor<4x3x!FHE.eint<2>>
|
||||
%0 = "FHELinalg.matmul_eint_int"(%x, %y): (tensor<4x3x!FHE.eint<2>>, tensor<3x2xi3>) -> tensor<4x2x!FHE.eint<2>>
|
||||
return %0 : tensor<4x2x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// CHECK: func.func @matmul_int_eint_clear_zero(%[[a0:.*]]: tensor<3x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"()
|
||||
// CHECK-NEXT: return %[[v0]]
|
||||
func.func @matmul_int_eint_clear_zero(%y: tensor<3x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>> {
|
||||
%x = arith.constant dense<0> : tensor<4x3xi3>
|
||||
%0 = "FHELinalg.matmul_int_eint"(%x, %y): (tensor<4x3xi3>, tensor<3x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>>
|
||||
return %0 : tensor<4x2x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// CHECK: func.func @matmul_int_eint_encrypted_zero(%[[a0:.*]]: tensor<4x3xi3>) -> tensor<4x2x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"()
|
||||
// CHECK-NEXT: return %[[v0]]
|
||||
func.func @matmul_int_eint_encrypted_zero(%x: tensor<4x3xi3>) -> tensor<4x2x!FHE.eint<2>> {
|
||||
%y = "FHE.zero_tensor"() : () -> tensor<3x2x!FHE.eint<2>>
|
||||
%0 = "FHELinalg.matmul_int_eint"(%x, %y): (tensor<4x3xi3>, tensor<3x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>>
|
||||
return %0 : tensor<4x2x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user