feat: fold tensor ops with constant values

- add/sub with constant tensors of 0s
- mul with constant tensors of 1s
Author: youben11
Date: 2022-05-24 10:56:22 +01:00
Committed by: Ayoub Benaissa
Parent: 8aa6f3e809
Commit: 8d04c1e4af
3 changed files with 130 additions and 0 deletions


@@ -73,6 +73,8 @@ def AddEintIntOp : FHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, Tensor
      build($_builder, $_state, rhs.getType(), rhs, lhs);
    }]>
  ];

  let hasFolder = 1;
}
def AddEintOp : FHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinaryEint]> {
@@ -178,6 +180,8 @@ def SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tensor
      build($_builder, $_state, lhs.getType(), rhs, lhs);
    }]>
  ];

  let hasFolder = 1;
}
def NegEintOp : FHELinalg_Op<"neg_eint", [TensorUnaryEint]> {
@@ -257,6 +261,8 @@ def MulEintIntOp : FHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, Tensor
  );
  let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);

  let hasFolder = 1;
}
def ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", []> {
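
For context (a sketch, not part of the diff): `let hasFolder = 1;` tells ODS (mlir-tblgen) to declare a per-op fold hook on the generated C++ op class; the hand-written definitions in the next file supply its body. With the MLIR revision in use at the time of this commit, the declared hook looks roughly like:

// Declared by mlir-tblgen on each op class that sets hasFolder. `operands`
// holds one Attribute per operand: its constant value when known at compile
// time, or a null Attribute otherwise. Returning an existing Value (or an
// Attribute) replaces the op's result; returning nullptr means "no fold".
::mlir::OpFoldResult fold(::llvm::ArrayRef<::mlir::Attribute> operands);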


@@ -1641,6 +1641,48 @@ mlir::LogicalResult verifyTranspose(TransposeOp &transposeOp) {
  return mlir::success();
}

// Fold addition with a constant tensor of zeros: x + 0 == x, so the
// encrypted operand can be returned directly.
OpFoldResult AddEintIntOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2);
  // The integer operand must be a compile-time dense constant.
  auto toAdd = operands[1].dyn_cast_or_null<mlir::DenseIntElementsAttr>();
  if (toAdd == nullptr)
    return nullptr;
  // Every element must be zero, otherwise the op cannot be folded away.
  for (int64_t i = 0; i < toAdd.size(); i++) {
    llvm::APInt cst = toAdd.getFlatValue<llvm::APInt>(i);
    if (cst != 0)
      return nullptr;
  }
  return getOperand(0);
}

// Fold subtraction when the constant integer operand is a tensor of zeros:
// the encrypted operand is returned unchanged.
OpFoldResult SubIntEintOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2);
  // The integer operand (lhs) must be a compile-time dense constant.
  auto toSub = operands[0].dyn_cast_or_null<mlir::DenseIntElementsAttr>();
  if (toSub == nullptr)
    return nullptr;
  for (int64_t i = 0; i < toSub.size(); i++) {
    llvm::APInt cst = toSub.getFlatValue<llvm::APInt>(i);
    if (cst != 0)
      return nullptr;
  }
  return getOperand(1);
}

// Fold multiplication with a constant tensor of ones: x * 1 == x, so the
// encrypted operand can be returned directly.
OpFoldResult MulEintIntOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2);
  // The integer operand must be a compile-time dense constant.
  auto toMul = operands[1].dyn_cast_or_null<mlir::DenseIntElementsAttr>();
  if (toMul == nullptr)
    return nullptr;
  for (int64_t i = 0; i < toMul.size(); i++) {
    llvm::APInt cst = toMul.getFlatValue<llvm::APInt>(i);
    if (cst != 1)
      return nullptr;
  }
  return getOperand(0);
}

} // namespace FHELinalg
} // namespace concretelang
} // namespace mlir
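
The three folders share the same check: the integer operand must be a dense constant whose elements all equal the identity value. A possible factoring is sketched below; it uses only standard DenseIntElementsAttr accessors (isSplat, getSplatValue, getValues), the helper name isConstTensorOf is hypothetical, and the required headers are assumed to come from the file's existing MLIR/LLVM includes:

// Sketch only (not part of the commit): true when `attr` is a dense integer
// constant whose elements all equal `value`. The splat fast path covers
// broadcast constants like dense<0> without iterating every element.
static bool isConstTensorOf(mlir::Attribute attr, uint64_t value) {
  auto dense = attr.dyn_cast_or_null<mlir::DenseIntElementsAttr>();
  if (!dense)
    return false;
  if (dense.isSplat())
    return dense.getSplatValue<llvm::APInt>() == value;
  return llvm::all_of(dense.getValues<llvm::APInt>(),
                      [&](const llvm::APInt &v) { return v == value; });
}

With such a helper, AddEintIntOp::fold above would reduce to returning getOperand(0) when isConstTensorOf(operands[1], 0) holds and nullptr otherwise, and likewise for the other two folders.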


@@ -0,0 +1,82 @@
// RUN: concretecompiler --action=dump-fhe %s 2>&1 | FileCheck %s

// CHECK: func @add_eint_int_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
// CHECK-NEXT: }
func @add_eint_int_1D(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %a1 = arith.constant dense<[0, 0, 0, 0]> : tensor<4xi3>
  %1 = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
  return %1 : tensor<4x!FHE.eint<2>>
}

// CHECK: func @add_eint_int_1D_broadcast(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
// CHECK-NEXT: }
func @add_eint_int_1D_broadcast(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %a1 = arith.constant dense<[0]> : tensor<1xi3>
  %1 = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<1xi3>) -> tensor<4x!FHE.eint<2>>
  return %1 : tensor<4x!FHE.eint<2>>
}

// CHECK: func @add_eint_int_2D_broadcast(%[[a0:.*]]: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
// CHECK-NEXT: return %[[a0]] : tensor<4x3x!FHE.eint<2>>
// CHECK-NEXT: }
func @add_eint_int_2D_broadcast(%a0: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
  %a1 = arith.constant dense<[[0]]> : tensor<1x1xi3>
  %1 = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x3x!FHE.eint<2>>, tensor<1x1xi3>) -> tensor<4x3x!FHE.eint<2>>
  return %1 : tensor<4x3x!FHE.eint<2>>
}
// CHECK: func @sub_int_eint_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
// CHECK-NEXT: }
func @sub_int_eint_1D(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %a1 = arith.constant dense<[0, 0, 0, 0]> : tensor<4xi3>
  %1 = "FHELinalg.sub_int_eint"(%a1, %a0) : (tensor<4xi3>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>>
  return %1 : tensor<4x!FHE.eint<2>>
}

// CHECK: func @sub_int_eint_1D_broadcast(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
// CHECK-NEXT: }
func @sub_int_eint_1D_broadcast(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %a1 = arith.constant dense<[0]> : tensor<1xi3>
  %1 = "FHELinalg.sub_int_eint"(%a1, %a0) : (tensor<1xi3>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>>
  return %1 : tensor<4x!FHE.eint<2>>
}

// CHECK: func @sub_int_eint_2D_broadcast(%[[a0:.*]]: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
// CHECK-NEXT: return %[[a0]] : tensor<4x3x!FHE.eint<2>>
// CHECK-NEXT: }
func @sub_int_eint_2D_broadcast(%a0: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
  %a1 = arith.constant dense<[[0]]> : tensor<1x1xi3>
  %1 = "FHELinalg.sub_int_eint"(%a1, %a0) : (tensor<1x1xi3>, tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>>
  return %1 : tensor<4x3x!FHE.eint<2>>
}
// CHECK: func @mul_eint_int_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
// CHECK-NEXT: }
func @mul_eint_int_1D(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %a1 = arith.constant dense<[1, 1, 1, 1]> : tensor<4xi3>
  %1 = "FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
  return %1 : tensor<4x!FHE.eint<2>>
}

// CHECK: func @mul_eint_int_1D_broadcast(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
// CHECK-NEXT: }
func @mul_eint_int_1D_broadcast(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %a1 = arith.constant dense<[1]> : tensor<1xi3>
  %1 = "FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<1xi3>) -> tensor<4x!FHE.eint<2>>
  return %1 : tensor<4x!FHE.eint<2>>
}

// CHECK: func @mul_eint_int_2D_broadcast(%[[a0:.*]]: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
// CHECK-NEXT: return %[[a0]] : tensor<4x3x!FHE.eint<2>>
// CHECK-NEXT: }
func @mul_eint_int_2D_broadcast(%a0: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
  %a1 = arith.constant dense<[[1]]> : tensor<1x1xi3>
  %1 = "FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x3x!FHE.eint<2>>, tensor<1x1xi3>) -> tensor<4x3x!FHE.eint<2>>
  return %1 : tensor<4x3x!FHE.eint<2>>
}
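
These tests rely on the pipeline behind --action=dump-fhe invoking the folders, presumably through canonicalization or a greedy rewrite. In a generic MLIR setup the equivalent step would be running the canonicalizer over the module, roughly as sketched below (standard MLIR pass APIs; `context` and `module` stand in for an already-built MLIRContext and ModuleOp):

// Sketch only: op folds such as the ones above are applied by the
// canonicalizer, so after this runs, add_eint_int(x, dense<0>) collapses to x.
mlir::PassManager pm(&context);
pm.addPass(mlir::createCanonicalizerPass());
if (mlir::failed(pm.run(module)))
  llvm::errs() << "canonicalization failed\n";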