From 1c202ebaff1fb4667a12bed5db958ad655838043 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Mon, 18 Oct 2021 14:38:42 +0200 Subject: [PATCH] enhance(compiler): Support of tensor operators in MANP pass (close #169) --- compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp | 125 +++++++++++++++++- .../Dialect/HLFHE/Analysis/MANP_tensor.mlir | 122 +++++++++++++++++ 2 files changed, 246 insertions(+), 1 deletion(-) create mode 100644 compiler/tests/Dialect/HLFHE/Analysis/MANP_tensor.mlir diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp index 17e106fa6..83a19cf5b 100644 --- a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -121,6 +122,16 @@ static llvm::APInt APIntWidthExtendUMul(const llvm::APInt &lhs, return lhs.zext(targetWidth) * rhs.zext(targetWidth); } +// Returns the maximum value beetwen `lhs` and `rhs`, where both values are +// assumed to be positive. The bit width of the smaller `APInt` is extended +// before comparison via `APInt::ult`. +static llvm::APInt APIntUMax(const llvm::APInt &lhs, const llvm::APInt &rhs) { + if (APIntWidthExtendULT(lhs, rhs)) { + return rhs; + } + return lhs; +} + // Calculates the square of `i`. The bit width `i` is extended in // order to guarantee that the product fits into the resulting // `APInt`. @@ -372,6 +383,58 @@ static llvm::APInt getSqMANP( return APIntWidthExtendUMul(sqNorm, eNorm); } +static llvm::APInt getSqMANP( + mlir::tensor::ExtractOp op, + llvm::ArrayRef *> operandMANPs) { + + assert( + operandMANPs[0]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue(); + + return eNorm; +} + +static llvm::APInt getSqMANP( + mlir::tensor::FromElementsOp op, + llvm::ArrayRef *> operandMANPs) { + + auto max = std::max_element( + operandMANPs.begin(), operandMANPs.end(), + [](mlir::LatticeElement *const a, + mlir::LatticeElement *const b) { + return APIntWidthExtendULT(a->getValue().getMANP().getValue(), + b->getValue().getMANP().getValue()); + }); + return (*max)->getValue().getMANP().getValue(); +} + +static llvm::APInt getSqMANP( + mlir::tensor::ExtractSliceOp op, + llvm::ArrayRef *> operandMANPs) { + + assert( + operandMANPs[0]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + return operandMANPs[0]->getValue().getMANP().getValue(); +} + +static llvm::APInt getSqMANP( + mlir::tensor::InsertSliceOp op, + llvm::ArrayRef *> operandMANPs) { + + assert( + operandMANPs.size() >= 2 && + operandMANPs[0]->getValue().getMANP().hasValue() && + operandMANPs[1]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + return APIntUMax(operandMANPs[0]->getValue().getMANP().getValue(), + operandMANPs[1]->getValue().getMANP().getValue()); +} + struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; MANPAnalysis(mlir::MLIRContext *ctx, bool debug) @@ -387,6 +450,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { bool isDummy = false; llvm::APInt norm2SqEquiv; + // HLFHE Operaors if (auto dotOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(dotOp, operands); } else if (auto addEintIntOp = @@ -404,7 +468,58 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { } else if (llvm::isa(op) || llvm::isa(op)) { norm2SqEquiv = llvm::APInt{1, 1, false}; - } else if (llvm::isa(op)) { + } + // Tensor Operators + // ExtractOp + else if (auto extractOp = llvm::dyn_cast(op)) { + if (extractOp.result() + .getType() + .isa()) { + norm2SqEquiv = getSqMANP(extractOp, operands); + } else { + isDummy = true; + } + } + // ExtractSliceOp + else if (auto extractSliceOp = + llvm::dyn_cast(op)) { + if (extractSliceOp.result() + .getType() + .cast() + .getElementType() + .isa()) { + norm2SqEquiv = getSqMANP(extractSliceOp, operands); + } else { + isDummy = true; + } + } + // InsertSliceOp + else if (auto insertSliceOp = + llvm::dyn_cast(op)) { + if (insertSliceOp.result() + .getType() + .cast() + .getElementType() + .isa()) { + norm2SqEquiv = getSqMANP(insertSliceOp, operands); + } else { + isDummy = true; + } + } + // FromElementOp + else if (auto fromOp = llvm::dyn_cast(op)) { + if (fromOp.result() + .getType() + .cast() + .getElementType() + .isa()) { + norm2SqEquiv = getSqMANP(fromOp, operands); + } else { + isDummy = true; + } + } + + else if (llvm::isa(op)) { isDummy = true; } else if (llvm::isa( *op->getDialect())) { @@ -488,6 +603,14 @@ protected: mlir::zamalang::HLFHE::EncryptedIntegerType eTy = res.getType() .dyn_cast_or_null(); + if (eTy == nullptr) { + auto tensorTy = res.getType().dyn_cast_or_null(); + if (tensorTy != nullptr) { + eTy = tensorTy.getElementType() + .dyn_cast_or_null< + mlir::zamalang::HLFHE::EncryptedIntegerType>(); + } + } if (eTy) { bool upd = false; diff --git a/compiler/tests/Dialect/HLFHE/Analysis/MANP_tensor.mlir b/compiler/tests/Dialect/HLFHE/Analysis/MANP_tensor.mlir new file mode 100644 index 000000000..79281998b --- /dev/null +++ b/compiler/tests/Dialect/HLFHE/Analysis/MANP_tensor.mlir @@ -0,0 +1,122 @@ +// RUN: zamacompiler --split-input-file --entry-dialect=hlfhe --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s + +func @tensor_from_elements_1(%a: !HLFHE.eint<2>, %b: !HLFHE.eint<2>, %c: !HLFHE.eint<2>, %d: !HLFHE.eint<2>) -> tensor<4x!HLFHE.eint<2>> +{ + // The MANP value is 1 as all operands are function arguments + // CHECK: %[[ret:.*]] = tensor.from_elements %[[a:.*]], %[[b:.*]], %[[c:.*]], %[[d:.*]] {MANP = 1 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>> + %0 = tensor.from_elements %a, %b, %c, %d : tensor<4x!HLFHE.eint<2>> + + return %0 : tensor<4x!HLFHE.eint<2>> +} + +// ----- + +func @tensor_from_elements_2(%a: !HLFHE.eint<2>, %b: !HLFHE.eint<2>, %c: !HLFHE.eint<2>, %d: !HLFHE.eint<2>) -> tensor<4x!HLFHE.eint<2>> +{ + %cst = constant 3 : i3 + + // CHECK: %[[V0:.*]] = "HLFHE.add_eint_int"(%[[a:.*]], %[[cst:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2> + %0 = "HLFHE.add_eint_int"(%a, %cst) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2> + + // The MANP value is 4, i.e. the max of all of its operands + // CHECK: %[[V1:.*]] = tensor.from_elements %[[V0:.*]], %[[b:.*]], %[[c:.*]], %[[d:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>> + %1 = tensor.from_elements %0, %b, %c, %d : tensor<4x!HLFHE.eint<2>> + + return %1 : tensor<4x!HLFHE.eint<2>> +} + +// ----- + +func @tensor_extract_1(%t: tensor<4x!HLFHE.eint<2>>) -> !HLFHE.eint<2> +{ + %cst = constant 1 : index + + // The MANP value is 1 as the tensor operand is a function argument + // CHECK: %[[ret:.*]] = tensor.extract %[[t:.*]][%[[c1:.*]]] {MANP = 1 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>> + %0 = tensor.extract %t[%cst] : tensor<4x!HLFHE.eint<2>> + + return %0 : !HLFHE.eint<2> +} + +// ----- + +func @tensor_extract_2(%a: !HLFHE.eint<2>) -> !HLFHE.eint<2> +{ + %c1 = constant 1 : index + %c3 = constant 3 : i3 + + // CHECK: %[[V0:.*]] = "HLFHE.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2> + %0 = "HLFHE.add_eint_int"(%a, %c3) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2> + // CHECK: %[[V1:.*]] = tensor.from_elements %[[V0:.*]], %[[a:.*]], %[[a:.*]], %[[a:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>> + %1 = tensor.from_elements %0, %a, %a, %a : tensor<4x!HLFHE.eint<2>> + // CHECK: %[[ret:.*]] = tensor.extract %[[t:.*]][%[[c3:.*]]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>> + %2 = tensor.extract %1[%c1] : tensor<4x!HLFHE.eint<2>> + + return %2 : !HLFHE.eint<2> +} + +// ----- + +func @tensor_extract_slice_1(%t: tensor<2x10x!HLFHE.eint<2>>) -> tensor<1x5x!HLFHE.eint<2>> +{ + // CHECK: %[[V0:.*]] = tensor.extract_slice %[[t:.*]][1, 5] [1, 5] [1, 1] {MANP = 1 : ui{{[[0-9]+}}} : tensor<2x10x!HLFHE.eint<2>> to tensor<1x5x!HLFHE.eint<2>> + %0 = tensor.extract_slice %t[1, 5] [1, 5] [1, 1] : tensor<2x10x!HLFHE.eint<2>> to tensor<1x5x!HLFHE.eint<2>> + + return %0 : tensor<1x5x!HLFHE.eint<2>> +} + +// ----- + +func @tensor_extract_slice_2(%a: !HLFHE.eint<2>) -> tensor<2x!HLFHE.eint<2>> +{ + %c3 = constant 3 : i3 + + // CHECK: %[[V0:.*]] = "HLFHE.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2> + %0 = "HLFHE.add_eint_int"(%a, %c3) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2> + // CHECK: %[[V1:.*]] = tensor.from_elements %[[V0:.*]], %[[a:.*]], %[[a:.*]], %[[a:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>> + %1 = tensor.from_elements %0, %a, %a, %a : tensor<4x!HLFHE.eint<2>> + // CHECK: tensor.extract_slice %[[V1]][2] [2] [1] {MANP = 4 : ui{{[0-9]+}}} : tensor<4x!HLFHE.eint<2>> to tensor<2x!HLFHE.eint<2>> + %2 = tensor.extract_slice %1[2] [2] [1] : tensor<4x!HLFHE.eint<2>> to tensor<2x!HLFHE.eint<2>> + + return %2 : tensor<2x!HLFHE.eint<2>> +} + +// ----- + +func @tensor_insert_slice_1(%t0: tensor<2x10x!HLFHE.eint<2>>, %t1: tensor<2x2x!HLFHE.eint<2>>) -> tensor<2x10x!HLFHE.eint<2>> +{ + // %[[V0:.*]] = tensor.insert_slice %[[t1:.*]] into %[[t0:.*]][0, 5] [2, 2] [1, 1] {MANP = 1 : ui{{[[0-9]+}}} : tensor<2x2x!HLFHE.eint<2>> into tensor<2x10x!HLFHE.eint<2>> + %0 = tensor.insert_slice %t1 into %t0[0, 5] [2, 2] [1, 1] : tensor<2x2x!HLFHE.eint<2>> into tensor<2x10x!HLFHE.eint<2>> + + return %0 : tensor<2x10x!HLFHE.eint<2>> +} + +// ----- + +func @tensor_insert_slice_2(%a: !HLFHE.eint<5>) -> tensor<4x!HLFHE.eint<5>> +{ + %c3 = constant 3 : i6 + %c6 = constant 6 : i6 + + // CHECK: %[[V0:.*]] = "HLFHE.add_eint_int"(%[[a:.*]], %[[c3:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<5>, i6) -> !HLFHE.eint<5> + %v0 = "HLFHE.add_eint_int"(%a, %c3) : (!HLFHE.eint<5>, i6) -> !HLFHE.eint<5> + // CHECK: %[[V1:.*]] = "HLFHE.add_eint_int"(%[[a:.*]], %[[c6:.*]]) {MANP = 7 : ui{{[0-9]+}}} : (!HLFHE.eint<5>, i6) -> !HLFHE.eint<5> + %v1 = "HLFHE.add_eint_int"(%a, %c6) : (!HLFHE.eint<5>, i6) -> !HLFHE.eint<5> + + // CHECK: %[[T0:.*]] = tensor.from_elements %[[V0:.*]], %[[V0:.*]], %[[V0:.*]], %[[V0:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<5>> + %t0 = tensor.from_elements %v0, %v0, %v0, %v0 : tensor<4x!HLFHE.eint<5>> + + // CHECK: %[[T1:.*]] = tensor.from_elements %[[V1:.*]], %[[V1:.*]] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!HLFHE.eint<5>> + %t1 = tensor.from_elements %v1, %v1 : tensor<2x!HLFHE.eint<5>> + + // CHECK: %[[T2:.*]] = tensor.insert_slice %[[T1]] into %[[T0]][0] [2] [1] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!HLFHE.eint<5>> into tensor<4x!HLFHE.eint<5>> + %t2 = tensor.insert_slice %t1 into %t0[0] [2] [1] : tensor<2x!HLFHE.eint<5>> into tensor<4x!HLFHE.eint<5>> + + // CHECK: %[[T3:.*]] = tensor.from_elements %[[V0:.*]], %[[V0:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<2x!HLFHE.eint<5>> + %t3 = tensor.from_elements %v0, %v0 : tensor<2x!HLFHE.eint<5>> + + // CHECK: %[[T4:.*]] = tensor.insert_slice %[[T3]] into %[[T2]][0] [2] [1] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!HLFHE.eint<5>> into tensor<4x!HLFHE.eint<5>> + %t4 = tensor.insert_slice %t3 into %t2[0] [2] [1] : tensor<2x!HLFHE.eint<5>> into tensor<4x!HLFHE.eint<5>> + + return %t0 : tensor<4x!HLFHE.eint<5>> +}