mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: fold tensor ops with constant values
add/sub with constant tensors of 0s; mul with constant tensors of 1s
This commit is contained in:
@@ -73,6 +73,8 @@ def AddEintIntOp : FHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, Tensor
|
||||
build($_builder, $_state, rhs.getType(), rhs, lhs);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def AddEintOp : FHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinaryEint]> {
|
||||
@@ -178,6 +180,8 @@ def SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tensor
|
||||
build($_builder, $_state, lhs.getType(), rhs, lhs);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def NegEintOp : FHELinalg_Op<"neg_eint", [TensorUnaryEint]> {
|
||||
@@ -257,6 +261,8 @@ def MulEintIntOp : FHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, Tensor
|
||||
);
|
||||
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", []> {
|
||||
|
||||
@@ -1641,6 +1641,48 @@ mlir::LogicalResult verifyTranspose(TransposeOp &transposeOp) {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Avoid addition with constant tensor of 0s
|
||||
OpFoldResult AddEintIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2);
|
||||
auto toAdd = operands[1].dyn_cast_or_null<mlir::DenseIntElementsAttr>();
|
||||
if (toAdd == nullptr)
|
||||
return nullptr;
|
||||
for (int64_t i = 0; i < toAdd.size(); i++) {
|
||||
llvm::APInt cst = toAdd.getFlatValue<llvm::APInt>(i);
|
||||
if (cst != 0)
|
||||
return nullptr;
|
||||
}
|
||||
return getOperand(0);
|
||||
}
|
||||
|
||||
// Avoid subtraction with constant tensor of 0s
|
||||
OpFoldResult SubIntEintOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2);
|
||||
auto toSub = operands[0].dyn_cast_or_null<mlir::DenseIntElementsAttr>();
|
||||
if (toSub == nullptr)
|
||||
return nullptr;
|
||||
for (int64_t i = 0; i < toSub.size(); i++) {
|
||||
llvm::APInt cst = toSub.getFlatValue<llvm::APInt>(i);
|
||||
if (cst != 0)
|
||||
return nullptr;
|
||||
}
|
||||
return getOperand(1);
|
||||
}
|
||||
|
||||
// Avoid multiplication with constant tensor of 1s
|
||||
OpFoldResult MulEintIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2);
|
||||
auto toMul = operands[1].dyn_cast_or_null<mlir::DenseIntElementsAttr>();
|
||||
if (toMul == nullptr)
|
||||
return nullptr;
|
||||
for (int64_t i = 0; i < toMul.size(); i++) {
|
||||
llvm::APInt cst = toMul.getFlatValue<llvm::APInt>(i);
|
||||
if (cst != 1)
|
||||
return nullptr;
|
||||
}
|
||||
return getOperand(0);
|
||||
}
|
||||
|
||||
} // namespace FHELinalg
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
82
compiler/tests/Dialect/FHELinalg/FHELinalg/folding.mlir
Normal file
82
compiler/tests/Dialect/FHELinalg/FHELinalg/folding.mlir
Normal file
@@ -0,0 +1,82 @@
|
||||
// RUN: concretecompiler --action=dump-fhe %s 2>&1| FileCheck %s
|
||||
|
||||
// Folding of FHELinalg.add_eint_int: adding an all-zero clear tensor is
// the identity, so the op must be elided and the encrypted input returned.

// CHECK: func @add_eint_int_1D(%[[in:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: return %[[in]] : tensor<4x!FHE.eint<2>>
// CHECK-NEXT: }
func @add_eint_int_1D(%x: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %zeros = arith.constant dense<[0, 0, 0, 0]> : tensor<4xi3>
  %res = "FHELinalg.add_eint_int"(%x, %zeros) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
  return %res: tensor<4x!FHE.eint<2>>
}

// Same identity, zero broadcast from a 1-element tensor.
// CHECK: func @add_eint_int_1D_broadcast(%[[in:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: return %[[in]] : tensor<4x!FHE.eint<2>>
// CHECK-NEXT: }
func @add_eint_int_1D_broadcast(%x: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %zeros = arith.constant dense<[0]> : tensor<1xi3>
  %res = "FHELinalg.add_eint_int"(%x, %zeros) : (tensor<4x!FHE.eint<2>>, tensor<1xi3>) -> tensor<4x!FHE.eint<2>>
  return %res: tensor<4x!FHE.eint<2>>
}

// Same identity, zero broadcast across two dimensions.
// CHECK: func @add_eint_int_2D_broadcast(%[[in:.*]]: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
// CHECK-NEXT: return %[[in]] : tensor<4x3x!FHE.eint<2>>
// CHECK-NEXT: }
func @add_eint_int_2D_broadcast(%x: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
  %zeros = arith.constant dense<[[0]]> : tensor<1x1xi3>
  %res = "FHELinalg.add_eint_int"(%x, %zeros) : (tensor<4x3x!FHE.eint<2>>, tensor<1x1xi3>) -> tensor<4x3x!FHE.eint<2>>
  return %res: tensor<4x3x!FHE.eint<2>>
}
|
||||
// FHELinalg.sub_int_eint computes `lhs - rhs`; with an all-zero clear lhs
// the result is `0 - x == -x`, NOT `x`. Folding the op away to its
// encrypted operand would be incorrect, so these tests require the op to
// be preserved after the FHE dump.

// CHECK: func @sub_int_eint_1D
// CHECK: "FHELinalg.sub_int_eint"
func @sub_int_eint_1D(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %a1 = arith.constant dense<[0, 0, 0, 0]> : tensor<4xi3>
  %1 = "FHELinalg.sub_int_eint"(%a1, %a0) : (tensor<4xi3>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>>
  return %1: tensor<4x!FHE.eint<2>>
}

// CHECK: func @sub_int_eint_1D_broadcast
// CHECK: "FHELinalg.sub_int_eint"
func @sub_int_eint_1D_broadcast(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %a1 = arith.constant dense<[0]> : tensor<1xi3>
  %1 = "FHELinalg.sub_int_eint"(%a1, %a0) : (tensor<1xi3>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>>
  return %1: tensor<4x!FHE.eint<2>>
}

// CHECK: func @sub_int_eint_2D_broadcast
// CHECK: "FHELinalg.sub_int_eint"
func @sub_int_eint_2D_broadcast(%a0: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
  %a1 = arith.constant dense<[[0]]> : tensor<1x1xi3>
  %1 = "FHELinalg.sub_int_eint"(%a1, %a0) : (tensor<1x1xi3>, tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>>
  return %1: tensor<4x3x!FHE.eint<2>>
}
|
||||
|
||||
// Folding of FHELinalg.mul_eint_int: multiplying by an all-one clear
// tensor is the identity, so the op must be elided and the encrypted
// input returned.

// CHECK: func @mul_eint_int_1D(%[[in:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: return %[[in]] : tensor<4x!FHE.eint<2>>
// CHECK-NEXT: }
func @mul_eint_int_1D(%x: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %ones = arith.constant dense<[1, 1, 1, 1]> : tensor<4xi3>
  %res = "FHELinalg.mul_eint_int"(%x, %ones) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
  return %res: tensor<4x!FHE.eint<2>>
}

// Same identity, one broadcast from a 1-element tensor.
// CHECK: func @mul_eint_int_1D_broadcast(%[[in:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: return %[[in]] : tensor<4x!FHE.eint<2>>
// CHECK-NEXT: }
func @mul_eint_int_1D_broadcast(%x: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %ones = arith.constant dense<[1]> : tensor<1xi3>
  %res = "FHELinalg.mul_eint_int"(%x, %ones) : (tensor<4x!FHE.eint<2>>, tensor<1xi3>) -> tensor<4x!FHE.eint<2>>
  return %res: tensor<4x!FHE.eint<2>>
}

// Same identity, one broadcast across two dimensions.
// CHECK: func @mul_eint_int_2D_broadcast(%[[in:.*]]: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
// CHECK-NEXT: return %[[in]] : tensor<4x3x!FHE.eint<2>>
// CHECK-NEXT: }
func @mul_eint_int_2D_broadcast(%x: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
  %ones = arith.constant dense<[[1]]> : tensor<1x1xi3>
  %res = "FHELinalg.mul_eint_int"(%x, %ones) : (tensor<4x3x!FHE.eint<2>>, tensor<1x1xi3>) -> tensor<4x3x!FHE.eint<2>>
  return %res: tensor<4x3x!FHE.eint<2>>
}
|
||||
Reference in New Issue
Block a user