diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
index 072cf2d46..f2a913ab9 100644
--- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
+++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
@@ -899,6 +899,18 @@ static llvm::APInt getSqMANP(
   return accNorm;
 }
 
+static llvm::APInt getSqMANP(
+    mlir::concretelang::FHELinalg::TransposeOp op,
+    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
+
+  assert(
+      operandMANPs.size() == 1 &&
+      operandMANPs[0]->getValue().getMANP().hasValue() &&
+      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
+
+  return operandMANPs[0]->getValue().getMANP().getValue();
+}
+
 static llvm::APInt getSqMANP(
     mlir::tensor::ExtractOp op,
     llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -1220,7 +1232,20 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
                    llvm::dyn_cast<mlir::concretelang::FHELinalg::Conv2dOp>(
                        op)) {
       norm2SqEquiv = getSqMANP(conv2dOp, operands);
+    } else if (auto transposeOp =
+                   llvm::dyn_cast<mlir::concretelang::FHELinalg::TransposeOp>(
+                       op)) {
+      if (transposeOp.tensor()
+              .getType()
+              .cast<mlir::RankedTensorType>()
+              .getElementType()
+              .isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
+        norm2SqEquiv = getSqMANP(transposeOp, operands);
+      } else {
+        isDummy = true;
+      }
     }
+    // Tensor Operators
 
     // ExtractOp
     else if (auto extractOp = llvm::dyn_cast<mlir::tensor::ExtractOp>(op)) {
diff --git a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir
index ec5243deb..6d33c49e3 100644
--- a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir
+++ b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir
@@ -209,4 +209,28 @@ func @chain_add_eint_neg_eint(%e: !FHE.eint<2>) -> !FHE.eint<2>
   %1 = "FHE.neg_eint"(%0) : (!FHE.eint<2>) -> !FHE.eint<2>
 
   return %1 : !FHE.eint<2>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
+func @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> {
+  // CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) {MANP = 1 : ui1} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
+  // CHECK-NEXT: return %[[v0]] : tensor<5x4x3x!FHE.eint<6>>
+  // CHECK-NEXT: }
+  %c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
+  return %c : tensor<5x4x3x!FHE.eint<6>>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_eint_3D_after_op(%arg0: tensor<3x4x5x!FHE.eint<6>>, %arg1: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
+func @transpose_eint_3D_after_op(%arg0: tensor<3x4x5x!FHE.eint<6>>, %arg1: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> {
+  // CHECK-NEXT: %[[v0:.*]] = "FHELinalg.add_eint"(%arg0, %arg1) {MANP = 2 : ui{{[0-9]+}}} : (tensor<3x4x5x!FHE.eint<6>>, tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x4x5x!FHE.eint<6>>
+  // CHECK-NEXT: %[[v1:.*]] = "FHELinalg.transpose"(%[[v0]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
+  // CHECK-NEXT: return %[[v1]] : tensor<5x4x3x!FHE.eint<6>>
+  // CHECK-NEXT: }
+  %sum = "FHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x4x5x!FHE.eint<6>>, tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x4x5x!FHE.eint<6>>
+  %c = "FHELinalg.transpose"(%sum) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
+  return %c : tensor<5x4x3x!FHE.eint<6>>
 }
\ No newline at end of file