feat(compiler): Add support for tensor.extract operations in MANP pass

Add support for `tensor.extract` operations in the MANP pass. This
currently only supports extract operations on tensors of encrypted
integers, which are passed as function arguments, e.g.:

 func @extract_ith(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{
   %c = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
   return %c : !HLFHE.eint<5>
 }
This commit is contained in:
Andi Drebes
2021-10-15 14:35:08 +02:00
parent 941465060e
commit 0423a05db8

View File

@@ -1,4 +1,5 @@
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <zamalang/Dialect/HLFHE/Analysis/MANP.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
@@ -11,6 +12,7 @@
#include <llvm/ADT/SmallString.h>
#include <mlir/Analysis/DataFlowAnalysis.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/Pass/PassManager.h>
@@ -77,6 +79,19 @@ protected:
llvm::Optional<llvm::APInt> manp;
};
// Checks if `lhs` is equal to `rhs`, where both values are assumed to
// be positive. The bit width of the smaller `APInt` is extended
// before comparison via `APInt::operator==`.
static bool APIntWidthExtendCompare(const llvm::APInt &lhs,
const llvm::APInt &rhs) {
if (lhs.getBitWidth() < rhs.getBitWidth())
return lhs.zext(rhs.getBitWidth()) == rhs;
else if (lhs.getBitWidth() > rhs.getBitWidth())
return lhs == rhs.zext(lhs.getBitWidth());
else
return lhs == rhs;
}
// Checks if `lhs` is less than `rhs`, where both values are assumed
// to be positive. The bit width of the smaller `APInt` is extended
// before comparison via `APInt::ult`.
@@ -305,6 +320,37 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUAdd(a, b);
}
// Calculates the squared Minimal Arithmetic Noise Padding of a dot
// operation that is equivalent to an `tensor.extract`
// operation. Currently, this only supports extractions of elements
// from tensors passed as function arguments, for which the MANP is
// assumed to be 1.
static llvm::APInt getSqMANP(
mlir::tensor::ExtractOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
mlir::zamalang::HLFHE::EncryptedIntegerType elTy =
op.getOperand(0)
.getType()
.dyn_cast<mlir::TensorType>()
.getElementType()
.dyn_cast_or_null<mlir::zamalang::HLFHE::EncryptedIntegerType>();
assert(elTy && "Can only calculate MANP for tensor.extract operations on "
"HLFHE.eint tensors");
assert(operandMANPs.size() >= 1 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
"MANP value for tensor is unknown");
llvm::APInt one{1, 1, false};
assert(APIntWidthExtendCompare(
operandMANPs[0]->getValue().getMANP().getValue(), one) &&
"MANP value for tensor is not 1 as expected");
return one;
}
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `HLFHE.sub_int_eint` operation.
static llvm::APInt getSqMANP(
@@ -404,6 +450,17 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
} else if (llvm::isa<mlir::zamalang::HLFHE::ZeroEintOp>(op) ||
llvm::isa<mlir::zamalang::HLFHE::ApplyLookupTableEintOp>(op)) {
norm2SqEquiv = llvm::APInt{1, 1, false};
} else if (auto tensorExtractOp =
llvm::dyn_cast<mlir::tensor::ExtractOp>(op)) {
// Only handle extract operations that produce an encrypted integer
if (tensorExtractOp->getResultTypes()
.front()
.dyn_cast_or_null<
mlir::zamalang::HLFHE::EncryptedIntegerType>()) {
norm2SqEquiv = getSqMANP(tensorExtractOp, operands);
} else {
isDummy = true;
}
} else if (llvm::isa<mlir::ConstantOp>(op)) {
isDummy = true;
} else if (llvm::isa<mlir::zamalang::HLFHE::HLFHEDialect>(