diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp index 17e106fa6..7161d22b0 100644 --- a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -11,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -77,6 +79,19 @@ protected: llvm::Optional manp; }; +// Checks if `lhs` is equal to `rhs`, where both values are assumed to +// be positive. The bit width of the smaller `APInt` is extended +// before comparison via `APInt::operator==`. +static bool APIntWidthExtendCompare(const llvm::APInt &lhs, + const llvm::APInt &rhs) { + if (lhs.getBitWidth() < rhs.getBitWidth()) + return lhs.zext(rhs.getBitWidth()) == rhs; + else if (lhs.getBitWidth() > rhs.getBitWidth()) + return lhs == rhs.zext(lhs.getBitWidth()); + else + return lhs == rhs; +} + // Checks if `lhs` is less than `rhs`, where both values are assumed // to be positive. The bit width of the smaller `APInt` is extended // before comparison via `APInt::ult`. @@ -305,6 +320,37 @@ static llvm::APInt getSqMANP( return APIntWidthExtendUAdd(a, b); } +// Calculates the squared Minimal Arithmetic Noise Padding of a dot +// operation that is equivalent to an `tensor.extract` +// operation. Currently, this only supports extractions of elements +// from tensors passed as function arguments, for which the MANP is +// assumed to be 1. +static llvm::APInt getSqMANP( + mlir::tensor::ExtractOp op, + llvm::ArrayRef *> operandMANPs) { + mlir::zamalang::HLFHE::EncryptedIntegerType elTy = + op.getOperand(0) + .getType() + .dyn_cast() + .getElementType() + .dyn_cast_or_null(); + + assert(elTy && "Can only calculate MANP for tensor.extract operations on " + "HLFHE.eint tensors"); + + assert(operandMANPs.size() >= 1 && + operandMANPs[0]->getValue().getMANP().hasValue() && + "MANP value for tensor is unknown"); + + llvm::APInt one{1, 1, false}; + + assert(APIntWidthExtendCompare( + operandMANPs[0]->getValue().getMANP().getValue(), one) && + "MANP value for tensor is not 1 as expected"); + + return one; +} + // Calculates the squared Minimal Arithmetic Noise Padding of a dot operation // that is equivalent to an `HLFHE.sub_int_eint` operation. static llvm::APInt getSqMANP( @@ -404,6 +450,17 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { } else if (llvm::isa(op) || llvm::isa(op)) { norm2SqEquiv = llvm::APInt{1, 1, false}; + } else if (auto tensorExtractOp = + llvm::dyn_cast(op)) { + // Only handle extract operations that produce an encrypted integer + if (tensorExtractOp->getResultTypes() + .front() + .dyn_cast_or_null< + mlir::zamalang::HLFHE::EncryptedIntegerType>()) { + norm2SqEquiv = getSqMANP(tensorExtractOp, operands); + } else { + isDummy = true; + } } else if (llvm::isa(op)) { isDummy = true; } else if (llvm::isa(