enhance(compiler): Support of tensor operators in MANP pass (close #169)

This commit is contained in:
Quentin Bourgerie
2021-10-18 14:38:42 +02:00
parent fcc992db2b
commit 1c202ebaff
2 changed files with 246 additions and 1 deletions

View File

@@ -11,6 +11,7 @@
#include <llvm/ADT/SmallString.h>
#include <mlir/Analysis/DataFlowAnalysis.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/Pass/PassManager.h>
@@ -121,6 +122,16 @@ static llvm::APInt APIntWidthExtendUMul(const llvm::APInt &lhs,
return lhs.zext(targetWidth) * rhs.zext(targetWidth);
}
// Returns the maximum value beetwen `lhs` and `rhs`, where both values are
// assumed to be positive. The bit width of the smaller `APInt` is extended
// before comparison via `APInt::ult`.
static llvm::APInt APIntUMax(const llvm::APInt &lhs, const llvm::APInt &rhs) {
if (APIntWidthExtendULT(lhs, rhs)) {
return rhs;
}
return lhs;
}
// Calculates the square of `i`. The bit width `i` is extended in
// order to guarantee that the product fits into the resulting
// `APInt`.
@@ -372,6 +383,58 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUMul(sqNorm, eNorm);
}
static llvm::APInt getSqMANP(
mlir::tensor::ExtractOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
assert(
operandMANPs[0]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
return eNorm;
}
static llvm::APInt getSqMANP(
mlir::tensor::FromElementsOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
auto max = std::max_element(
operandMANPs.begin(), operandMANPs.end(),
[](mlir::LatticeElement<MANPLatticeValue> *const a,
mlir::LatticeElement<MANPLatticeValue> *const b) {
return APIntWidthExtendULT(a->getValue().getMANP().getValue(),
b->getValue().getMANP().getValue());
});
return (*max)->getValue().getMANP().getValue();
}
static llvm::APInt getSqMANP(
mlir::tensor::ExtractSliceOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
assert(
operandMANPs[0]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
return operandMANPs[0]->getValue().getMANP().getValue();
}
static llvm::APInt getSqMANP(
mlir::tensor::InsertSliceOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
assert(
operandMANPs.size() >= 2 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
operandMANPs[1]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
return APIntUMax(operandMANPs[0]->getValue().getMANP().getValue(),
operandMANPs[1]->getValue().getMANP().getValue());
}
struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
@@ -387,6 +450,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
bool isDummy = false;
llvm::APInt norm2SqEquiv;
// HLFHE Operaors
if (auto dotOp = llvm::dyn_cast<mlir::zamalang::HLFHE::Dot>(op)) {
norm2SqEquiv = getSqMANP(dotOp, operands);
} else if (auto addEintIntOp =
@@ -404,7 +468,58 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
} else if (llvm::isa<mlir::zamalang::HLFHE::ZeroEintOp>(op) ||
llvm::isa<mlir::zamalang::HLFHE::ApplyLookupTableEintOp>(op)) {
norm2SqEquiv = llvm::APInt{1, 1, false};
} else if (llvm::isa<mlir::ConstantOp>(op)) {
}
// Tensor Operators
// ExtractOp
else if (auto extractOp = llvm::dyn_cast<mlir::tensor::ExtractOp>(op)) {
if (extractOp.result()
.getType()
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()) {
norm2SqEquiv = getSqMANP(extractOp, operands);
} else {
isDummy = true;
}
}
// ExtractSliceOp
else if (auto extractSliceOp =
llvm::dyn_cast<mlir::tensor::ExtractSliceOp>(op)) {
if (extractSliceOp.result()
.getType()
.cast<mlir::TensorType>()
.getElementType()
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()) {
norm2SqEquiv = getSqMANP(extractSliceOp, operands);
} else {
isDummy = true;
}
}
// InsertSliceOp
else if (auto insertSliceOp =
llvm::dyn_cast<mlir::tensor::InsertSliceOp>(op)) {
if (insertSliceOp.result()
.getType()
.cast<mlir::TensorType>()
.getElementType()
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()) {
norm2SqEquiv = getSqMANP(insertSliceOp, operands);
} else {
isDummy = true;
}
}
// FromElementOp
else if (auto fromOp = llvm::dyn_cast<mlir::tensor::FromElementsOp>(op)) {
if (fromOp.result()
.getType()
.cast<mlir::TensorType>()
.getElementType()
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()) {
norm2SqEquiv = getSqMANP(fromOp, operands);
} else {
isDummy = true;
}
}
else if (llvm::isa<mlir::ConstantOp>(op)) {
isDummy = true;
} else if (llvm::isa<mlir::zamalang::HLFHE::HLFHEDialect>(
*op->getDialect())) {
@@ -488,6 +603,14 @@ protected:
mlir::zamalang::HLFHE::EncryptedIntegerType eTy =
res.getType()
.dyn_cast_or_null<mlir::zamalang::HLFHE::EncryptedIntegerType>();
if (eTy == nullptr) {
auto tensorTy = res.getType().dyn_cast_or_null<mlir::TensorType>();
if (tensorTy != nullptr) {
eTy = tensorTy.getElementType()
.dyn_cast_or_null<
mlir::zamalang::HLFHE::EncryptedIntegerType>();
}
}
if (eTy) {
bool upd = false;

View File

@@ -0,0 +1,122 @@
// RUN: zamacompiler --split-input-file --entry-dialect=hlfhe --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s
func @tensor_from_elements_1(%a: !HLFHE.eint<2>, %b: !HLFHE.eint<2>, %c: !HLFHE.eint<2>, %d: !HLFHE.eint<2>) -> tensor<4x!HLFHE.eint<2>>
{
// The MANP value is 1 as all operands are function arguments
// CHECK: %[[ret:.*]] = tensor.from_elements %[[a:.*]], %[[b:.*]], %[[c:.*]], %[[d:.*]] {MANP = 1 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>>
%0 = tensor.from_elements %a, %b, %c, %d : tensor<4x!HLFHE.eint<2>>
return %0 : tensor<4x!HLFHE.eint<2>>
}
// -----
func @tensor_from_elements_2(%a: !HLFHE.eint<2>, %b: !HLFHE.eint<2>, %c: !HLFHE.eint<2>, %d: !HLFHE.eint<2>) -> tensor<4x!HLFHE.eint<2>>
{
%cst = constant 3 : i3
// CHECK: %[[V0:.*]] = "HLFHE.add_eint_int"(%[[a:.*]], %[[cst:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
%0 = "HLFHE.add_eint_int"(%a, %cst) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
// The MANP value is 4, i.e. the max of all of its operands
// CHECK: %[[V1:.*]] = tensor.from_elements %[[V0:.*]], %[[b:.*]], %[[c:.*]], %[[d:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>>
%1 = tensor.from_elements %0, %b, %c, %d : tensor<4x!HLFHE.eint<2>>
return %1 : tensor<4x!HLFHE.eint<2>>
}
// -----
func @tensor_extract_1(%t: tensor<4x!HLFHE.eint<2>>) -> !HLFHE.eint<2>
{
%cst = constant 1 : index
// The MANP value is 1 as the tensor operand is a function argument
// CHECK: %[[ret:.*]] = tensor.extract %[[t:.*]][%[[c1:.*]]] {MANP = 1 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>>
%0 = tensor.extract %t[%cst] : tensor<4x!HLFHE.eint<2>>
return %0 : !HLFHE.eint<2>
}
// -----
func @tensor_extract_2(%a: !HLFHE.eint<2>) -> !HLFHE.eint<2>
{
%c1 = constant 1 : index
%c3 = constant 3 : i3
// CHECK: %[[V0:.*]] = "HLFHE.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
%0 = "HLFHE.add_eint_int"(%a, %c3) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
// CHECK: %[[V1:.*]] = tensor.from_elements %[[V0:.*]], %[[a:.*]], %[[a:.*]], %[[a:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>>
%1 = tensor.from_elements %0, %a, %a, %a : tensor<4x!HLFHE.eint<2>>
// CHECK: %[[ret:.*]] = tensor.extract %[[t:.*]][%[[c3:.*]]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>>
%2 = tensor.extract %1[%c1] : tensor<4x!HLFHE.eint<2>>
return %2 : !HLFHE.eint<2>
}
// -----
func @tensor_extract_slice_1(%t: tensor<2x10x!HLFHE.eint<2>>) -> tensor<1x5x!HLFHE.eint<2>>
{
// CHECK: %[[V0:.*]] = tensor.extract_slice %[[t:.*]][1, 5] [1, 5] [1, 1] {MANP = 1 : ui{{[[0-9]+}}} : tensor<2x10x!HLFHE.eint<2>> to tensor<1x5x!HLFHE.eint<2>>
%0 = tensor.extract_slice %t[1, 5] [1, 5] [1, 1] : tensor<2x10x!HLFHE.eint<2>> to tensor<1x5x!HLFHE.eint<2>>
return %0 : tensor<1x5x!HLFHE.eint<2>>
}
// -----
func @tensor_extract_slice_2(%a: !HLFHE.eint<2>) -> tensor<2x!HLFHE.eint<2>>
{
%c3 = constant 3 : i3
// CHECK: %[[V0:.*]] = "HLFHE.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
%0 = "HLFHE.add_eint_int"(%a, %c3) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
// CHECK: %[[V1:.*]] = tensor.from_elements %[[V0:.*]], %[[a:.*]], %[[a:.*]], %[[a:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>>
%1 = tensor.from_elements %0, %a, %a, %a : tensor<4x!HLFHE.eint<2>>
// CHECK: tensor.extract_slice %[[V1]][2] [2] [1] {MANP = 4 : ui{{[0-9]+}}} : tensor<4x!HLFHE.eint<2>> to tensor<2x!HLFHE.eint<2>>
%2 = tensor.extract_slice %1[2] [2] [1] : tensor<4x!HLFHE.eint<2>> to tensor<2x!HLFHE.eint<2>>
return %2 : tensor<2x!HLFHE.eint<2>>
}
// -----
func @tensor_insert_slice_1(%t0: tensor<2x10x!HLFHE.eint<2>>, %t1: tensor<2x2x!HLFHE.eint<2>>) -> tensor<2x10x!HLFHE.eint<2>>
{
// %[[V0:.*]] = tensor.insert_slice %[[t1:.*]] into %[[t0:.*]][0, 5] [2, 2] [1, 1] {MANP = 1 : ui{{[[0-9]+}}} : tensor<2x2x!HLFHE.eint<2>> into tensor<2x10x!HLFHE.eint<2>>
%0 = tensor.insert_slice %t1 into %t0[0, 5] [2, 2] [1, 1] : tensor<2x2x!HLFHE.eint<2>> into tensor<2x10x!HLFHE.eint<2>>
return %0 : tensor<2x10x!HLFHE.eint<2>>
}
// -----
func @tensor_insert_slice_2(%a: !HLFHE.eint<5>) -> tensor<4x!HLFHE.eint<5>>
{
%c3 = constant 3 : i6
%c6 = constant 6 : i6
// CHECK: %[[V0:.*]] = "HLFHE.add_eint_int"(%[[a:.*]], %[[c3:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<5>, i6) -> !HLFHE.eint<5>
%v0 = "HLFHE.add_eint_int"(%a, %c3) : (!HLFHE.eint<5>, i6) -> !HLFHE.eint<5>
// CHECK: %[[V1:.*]] = "HLFHE.add_eint_int"(%[[a:.*]], %[[c6:.*]]) {MANP = 7 : ui{{[0-9]+}}} : (!HLFHE.eint<5>, i6) -> !HLFHE.eint<5>
%v1 = "HLFHE.add_eint_int"(%a, %c6) : (!HLFHE.eint<5>, i6) -> !HLFHE.eint<5>
// CHECK: %[[T0:.*]] = tensor.from_elements %[[V0:.*]], %[[V0:.*]], %[[V0:.*]], %[[V0:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<5>>
%t0 = tensor.from_elements %v0, %v0, %v0, %v0 : tensor<4x!HLFHE.eint<5>>
// CHECK: %[[T1:.*]] = tensor.from_elements %[[V1:.*]], %[[V1:.*]] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!HLFHE.eint<5>>
%t1 = tensor.from_elements %v1, %v1 : tensor<2x!HLFHE.eint<5>>
// CHECK: %[[T2:.*]] = tensor.insert_slice %[[T1]] into %[[T0]][0] [2] [1] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!HLFHE.eint<5>> into tensor<4x!HLFHE.eint<5>>
%t2 = tensor.insert_slice %t1 into %t0[0] [2] [1] : tensor<2x!HLFHE.eint<5>> into tensor<4x!HLFHE.eint<5>>
// CHECK: %[[T3:.*]] = tensor.from_elements %[[V0:.*]], %[[V0:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<2x!HLFHE.eint<5>>
%t3 = tensor.from_elements %v0, %v0 : tensor<2x!HLFHE.eint<5>>
// CHECK: %[[T4:.*]] = tensor.insert_slice %[[T3]] into %[[T2]][0] [2] [1] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!HLFHE.eint<5>> into tensor<4x!HLFHE.eint<5>>
%t4 = tensor.insert_slice %t3 into %t2[0] [2] [1] : tensor<2x!HLFHE.eint<5>> into tensor<4x!HLFHE.eint<5>>
return %t0 : tensor<4x!HLFHE.eint<5>>
}