feat: support FHELinalg.transpose in MANP

This commit is contained in:
youben11
2022-03-22 09:33:12 +01:00
committed by Ayoub Benaissa
parent 3a4723a0b8
commit 4e64b9e12a
2 changed files with 49 additions and 0 deletions

View File

@@ -899,6 +899,18 @@ static llvm::APInt getSqMANP(
return accNorm;
}
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::TransposeOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
assert(
operandMANPs.size() == 1 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
return operandMANPs[0]->getValue().getMANP().getValue();
}
static llvm::APInt getSqMANP(
mlir::tensor::ExtractOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -1220,7 +1232,20 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
llvm::dyn_cast<mlir::concretelang::FHELinalg::Conv2dOp>(
op)) {
norm2SqEquiv = getSqMANP(conv2dOp, operands);
} else if (auto transposeOp =
llvm::dyn_cast<mlir::concretelang::FHELinalg::TransposeOp>(
op)) {
if (transposeOp.tensor()
.getType()
.cast<mlir::TensorType>()
.getElementType()
.isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
norm2SqEquiv = getSqMANP(transposeOp, operands);
} else {
isDummy = true;
}
}
// Tensor Operators
// ExtractOp
else if (auto extractOp = llvm::dyn_cast<mlir::tensor::ExtractOp>(op)) {

View File

@@ -209,4 +209,28 @@ func @chain_add_eint_neg_eint(%e: !FHE.eint<2>) -> !FHE.eint<2>
%1 = "FHE.neg_eint"(%0) : (!FHE.eint<2>) -> !FHE.eint<2>
return %1 : !FHE.eint<2>
}
// -----
// CHECK-LABEL: @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
func @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) {MANP = 1 : ui1} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
// CHECK-NEXT: return %[[v0]] : tensor<5x4x3x!FHE.eint<6>>
// CHECK-NEXT: }
%c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
return %c : tensor<5x4x3x!FHE.eint<6>>
}
// -----
// CHECK-LABEL: @transpose_eint_3D_after_op(%arg0: tensor<3x4x5x!FHE.eint<6>>, %arg1: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
func @transpose_eint_3D_after_op(%arg0: tensor<3x4x5x!FHE.eint<6>>, %arg1: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.add_eint"(%arg0, %arg1) {MANP = 2 : ui{{[0-9]+}}} : (tensor<3x4x5x!FHE.eint<6>>, tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x4x5x!FHE.eint<6>>
// CHECK-NEXT: %[[v1:.*]] = "FHELinalg.transpose"(%[[v0]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
// CHECK-NEXT: return %[[v1]] : tensor<5x4x3x!FHE.eint<6>>
// CHECK-NEXT: }
%sum = "FHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x4x5x!FHE.eint<6>>, tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x4x5x!FHE.eint<6>>
%c = "FHELinalg.transpose"(%sum) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
return %c : tensor<5x4x3x!FHE.eint<6>>
}