mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 04:35:03 -05:00
feat: support FHELinalg.transpose in MANP
This commit is contained in:
@@ -899,6 +899,18 @@ static llvm::APInt getSqMANP(
|
||||
return accNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::TransposeOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
return operandMANPs[0]->getValue().getMANP().getValue();
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::tensor::ExtractOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
@@ -1220,7 +1232,20 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::Conv2dOp>(
|
||||
op)) {
|
||||
norm2SqEquiv = getSqMANP(conv2dOp, operands);
|
||||
} else if (auto transposeOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::TransposeOp>(
|
||||
op)) {
|
||||
if (transposeOp.tensor()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
||||
norm2SqEquiv = getSqMANP(transposeOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Tensor Operators
|
||||
// ExtractOp
|
||||
else if (auto extractOp = llvm::dyn_cast<mlir::tensor::ExtractOp>(op)) {
|
||||
|
||||
@@ -209,4 +209,28 @@ func @chain_add_eint_neg_eint(%e: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
%1 = "FHE.neg_eint"(%0) : (!FHE.eint<2>) -> !FHE.eint<2>
|
||||
|
||||
return %1 : !FHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
|
||||
func @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) {MANP = 1 : ui1} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
|
||||
// CHECK-NEXT: return %[[v0]] : tensor<5x4x3x!FHE.eint<6>>
|
||||
// CHECK-NEXT: }
|
||||
%c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
|
||||
return %c : tensor<5x4x3x!FHE.eint<6>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @transpose_eint_3D_after_op(%arg0: tensor<3x4x5x!FHE.eint<6>>, %arg1: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
|
||||
func @transpose_eint_3D_after_op(%arg0: tensor<3x4x5x!FHE.eint<6>>, %arg1: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.add_eint"(%arg0, %arg1) {MANP = 2 : ui{{[0-9]+}}} : (tensor<3x4x5x!FHE.eint<6>>, tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x4x5x!FHE.eint<6>>
|
||||
// CHECK-NEXT: %[[v1:.*]] = "FHELinalg.transpose"(%[[v0]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
|
||||
// CHECK-NEXT: return %[[v1]] : tensor<5x4x3x!FHE.eint<6>>
|
||||
// CHECK-NEXT: }
|
||||
%sum = "FHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x4x5x!FHE.eint<6>>, tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x4x5x!FHE.eint<6>>
|
||||
%c = "FHELinalg.transpose"(%sum) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
|
||||
return %c : tensor<5x4x3x!FHE.eint<6>>
|
||||
}
|
||||
Reference in New Issue
Block a user