// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete/blob/main/LICENSE.txt // for license information. #include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #define GEN_PASS_CLASSES #include namespace mlir { namespace concretelang { namespace { /// Returns `true` if the given value is a scalar or tensor argument of /// a function, for which a MANP of 1 can be assumed. static bool isEncryptedFunctionParameter(mlir::Value value) { if (!value.isa()) return false; mlir::Block *block = value.cast().getOwner(); if (!block || !block->getParentOp() || !llvm::isa(block->getParentOp())) { return false; } return mlir::concretelang::fhe::utils::isEncryptedValue(value); } /// The `MANPLatticeValue` represents the squared Minimal Arithmetic /// Noise Padding for an operation using the squared 2-norm of an /// equivalent dot operation. This can either be an actual value if the /// values for its predecessors have been calculated beforehand or an /// unknown value otherwise. struct MANPLatticeValue { MANPLatticeValue(std::optional manp = {}) : manp(manp) {} static MANPLatticeValue getPessimisticValueState(mlir::MLIRContext *context) { return MANPLatticeValue(); } static MANPLatticeValue getPessimisticValueState(mlir::Value value) { // Function arguments are assumed to require a Minimal Arithmetic // Noise Padding with a 2-norm of 1. // // TODO: Provide a mechanism to propagate Minimal Arithmetic Noise // Padding across function calls. if (isEncryptedFunctionParameter(value)) { return MANPLatticeValue(llvm::APInt{1, 1, false}); } else { // All other operations have an unknown Minimal Arithmetic Noise // Padding until an value for all predecessors has been // calculated. return MANPLatticeValue(); } } bool operator==(const MANPLatticeValue &rhs) const { return this->manp == rhs.manp; } static MANPLatticeValue join(const MANPLatticeValue &lhs, const MANPLatticeValue &rhs) { if (!lhs.getMANP().has_value()) return rhs; if (!rhs.getMANP().has_value()) return lhs; if (lhs.getMANP().value() == rhs.getMANP().value()) return lhs; assert(false && "Attempting to join two distinct initialized values"); return MANPLatticeValue{}; } void print(raw_ostream &os) const { if (manp.has_value()) os << manp.value(); else os << "(undefined)"; } std::optional getMANP() const { return manp; } protected: std::optional manp; }; class MANPLattice : public mlir::dataflow::Lattice { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MANPLattice) using Lattice::Lattice; }; /// Checks if `lhs` is less than `rhs`, where both values are assumed /// to be positive. The bit width of the smaller `APInt` is extended /// before comparison via `APInt::ult`. static bool APIntWidthExtendULT(const llvm::APInt &lhs, const llvm::APInt &rhs) { if (lhs.getBitWidth() < rhs.getBitWidth()) return lhs.zext(rhs.getBitWidth()).ult(rhs); else if (lhs.getBitWidth() > rhs.getBitWidth()) return lhs.ult(rhs.zext(lhs.getBitWidth())); else return lhs.ult(rhs); } /// Adds two `APInt` values, where both values are assumed to be /// positive. The bit width of the operands is extended in order to /// guarantee that the sum fits into the resulting `APInt`. static llvm::APInt APIntWidthExtendUAdd(const llvm::APInt &lhs, const llvm::APInt &rhs) { unsigned maxBits = std::max(lhs.getBitWidth(), rhs.getBitWidth()); // Make sure the required number of bits can be represented by the // `unsigned` argument of `zext`. assert(std::numeric_limits::max() - maxBits > 1); unsigned targetWidth = maxBits + 1; return lhs.zext(targetWidth) + rhs.zext(targetWidth); } /// Multiplies two `APInt` values, where both values are assumed to be /// positive. The bit width of the operands is extended in order to /// guarantee that the product fits into the resulting `APInt`. static llvm::APInt APIntWidthExtendUMul(const llvm::APInt &lhs, const llvm::APInt &rhs) { // Make sure the required number of bits can be represented by the // `unsigned` argument of `zext`. assert(std::numeric_limits::max() - std::max(lhs.getBitWidth(), rhs.getBitWidth()) > std::min(lhs.getBitWidth(), rhs.getBitWidth()) && "Required number of bits cannot be represented with an APInt"); unsigned targetWidth = lhs.getBitWidth() + rhs.getBitWidth(); return lhs.zext(targetWidth) * rhs.zext(targetWidth); } /// Returns the maximum value between `lhs` and `rhs`, where both values are /// assumed to be positive. The bit width of the smaller `APInt` is extended /// before comparison via `APInt::ult`. static llvm::APInt APIntUMax(const llvm::APInt &lhs, const llvm::APInt &rhs) { if (APIntWidthExtendULT(lhs, rhs)) { return rhs; } return lhs; } /// Calculates the square of `i`. The bit width `i` is extended in /// order to guarantee that the product fits into the resulting /// `APInt`. static llvm::APInt APIntWidthExtendUnsignedSq(const llvm::APInt &i) { // Make sure the required number of bits can be represented by the // `unsigned` argument of `zext`. assert(i.getBitWidth() < std::numeric_limits::max() / 2 && "Required number of bits cannot be represented with an APInt"); llvm::APInt ie = i.zext(2 * i.getBitWidth()); return ie * ie; } /// Calculates the square of the value of `i`. static llvm::APInt APIntWidthExtendSqForConstant(const llvm::APInt &i) { auto extI = i.sext(2 * i.getBitWidth()); return extI * extI; } /// Calculates the square root of `i` and rounds it to the next highest /// integer value (i.e., the square of the result is guaranteed to be /// greater or equal to `i`). static llvm::APInt APIntCeilSqrt(const llvm::APInt &i) { llvm::APInt res = i.sqrt(); llvm::APInt resSq = APIntWidthExtendUnsignedSq(res); if (APIntWidthExtendULT(resSq, i)) return APIntWidthExtendUAdd(res, llvm::APInt{1, 1, false}); else return res; } /// Returns a string representation of `i` assuming that `i` is an /// unsigned value. static std::string APIntToStringValUnsigned(const llvm::APInt &i) { llvm::SmallString<32> s; i.toStringUnsigned(s); return std::string(s.c_str()); } /// Returns the squared 2-norm for a dynamic integer by conservatively /// assuming that the integer's value is the maximum for the integer /// width. static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) { assert(t.isSignlessInteger() && "Type must be a signless integer type"); assert(std::numeric_limits::max() - t.getIntOrFloatBitWidth() > 1); // we consider the maximum value as a signed integer llvm::APInt maxVal = APInt::getMaxValue(t.getIntOrFloatBitWidth() - 1); return APIntWidthExtendUnsignedSq(maxVal); } static llvm::APInt getNoOpSqMANP(llvm::ArrayRef operandMANPs) { // Come from block arg as example if (operandMANPs.size() == 0) { return llvm::APInt{1, 1, false}; } assert(operandMANPs[0]->getValue().getMANP().has_value() && "Missing squared Minimal Arithmetic Noise Padding for encrypted " "operands"); llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); return eNorm; } /// Calculates the squared Minimal Arithmetic Noise Padding of an /// `FHELinalg.dot_eint_eint` operation. static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::DotEint op, llvm::ArrayRef operandMANPs) { assert(operandMANPs.size() == 2 && operandMANPs[0]->getValue().getMANP().has_value() && operandMANPs[1]->getValue().getMANP().has_value() && "Missing squared Minimal Arithmetic Noise Padding for encrypted " "operands"); llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().value(); llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().value(); auto rhsType = ((mlir::Type)op.getRhs().getType()).cast(); llvm::ArrayRef rhsShape = rhsType.getShape(); int64_t rhsDims = (int64_t)rhsShape.size(); assert(rhsDims == 1 && "In MANP computation dot product RHS expected to have " "a single dimension"); int64_t N = rhsShape[0]; // Compute output MANP: // Tlu output MANP is 1 llvm::APInt tlu = {1, 1, false}; // The element-wise multiplication is given by the // subtraction of two TLU outputs. The MANP of the multiplication is thus // the sum of the TLU MANPs llvm::APInt elemMulNorm = APIntWidthExtendUAdd(tlu, tlu); llvm::APInt accNorm = llvm::APInt{1, 0, false}; return APIntWidthExtendUMul(APIntWidthExtendUAdd(tlu, tlu), llvm::APInt(ceilLog2(N + 1), N, false)); } /// Calculates the squared Minimal Arithmetic Noise Padding of an unary FHE /// operation. static std::optional getSqMANP(mlir::concretelang::FHE::UnaryEint op, llvm::ArrayRef operandMANPs) { // not all unary ops taking an encrypted operand have a type signature // reflecting that, a check might be required (FHELinalg.TransposeOp is one // such known op) if (op.operandIntType().isa()) { assert(operandMANPs.size() == 1 && operandMANPs[0]->getValue().getMANP().has_value() && "Missing squared Minimal Arithmetic Noise Padding for encrypted " "operand"); return op.sqMANP(operandMANPs[0]->getValue().getMANP().value()); } else return {}; } /// Calculates the squared Minimal Arithmetic Noise Padding of a binary FHE /// operation with the first operand encrypted. static std::optional getSqMANP(mlir::concretelang::FHE::BinaryEintInt op, llvm::ArrayRef operandMANPs) { assert(operandMANPs.size() >= 2 && // conv2d has an optional 3rd operand operandMANPs[0]->getValue().getMANP().has_value() && "Missing squared Minimal Arithmetic Noise Padding for encrypted " "operand"); return op.sqMANP(operandMANPs[0]->getValue().getMANP().value()); } /// Calculates the squared Minimal Arithmetic Noise Padding of a binary FHE /// operation with the second operand encrypted. static std::optional getSqMANP(mlir::concretelang::FHE::BinaryIntEint op, llvm::ArrayRef operandMANPs) { assert( operandMANPs.size() == 2 && operandMANPs[1]->getValue().getMANP().has_value() && "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); return op.sqMANP(operandMANPs[1]->getValue().getMANP().value()); } /// Calculates the squared Minimal Arithmetic Noise Padding of a binary FHE /// operation with both operands encrypted. static std::optional getSqMANP(mlir::concretelang::FHE::BinaryEint op, llvm::ArrayRef operandMANPs) { assert(operandMANPs.size() == 2 && operandMANPs[0]->getValue().getMANP().has_value() && operandMANPs[1]->getValue().getMANP().has_value() && "Missing squared Minimal Arithmetic Noise Padding for encrypted " "operands"); return op.sqMANP(operandMANPs[0]->getValue().getMANP().value(), operandMANPs[1]->getValue().getMANP().value()); } static llvm::APInt sqMANP_mul_eint_int(llvm::APInt a, mlir::Type iTy, std::optional b) { assert(iTy.isSignlessInteger() && "Only multiplications with signless integers are currently allowed"); llvm::APInt sqNorm; if (b.has_value()) { // For a constant operand use actual constant to calculate 2-norm sqNorm = APIntWidthExtendSqForConstant(b.value()); } else { // For a dynamic operand conservatively assume that the value is // the maximum for the integer width sqNorm = conservativeIntNorm2Sq(iTy); } return APIntWidthExtendUMul(sqNorm, a); } static llvm::APInt sqMANP_mul_eint(llvm::APInt a, llvm::APInt b) { // a * b = ((a + b)^2 / 4) - ((a - b)^2 / 4) == tlu(a + b) - tlu(a - b) const llvm::APInt beforeTLUs = APIntWidthExtendUAdd(a, b); const llvm::APInt tlu = {1, 1, false}; const llvm::APInt result = APIntWidthExtendUAdd(tlu, tlu); return result; } /// Computes the squared vector norm as the maximum of all dot products over the /// destroyed dimension. The computation is recursive and can be seen as folding /// a tree of values where leaves compute the dot products and nodes choose the /// maximum of the children. There is a branching node at every shape dimension /// with the fanout equal to that dimension range, except for the destroyed /// dimension (where the fanout is 1). /// /// This function is expected to behave correctly on clear tensor shapes with /// any dimensionality, for example: /// /// MxN @ N -> M /// LxMxN @ N -> LxM /// KxLxMxN @ N -> KxLxM /// /// N @ NxP -> P /// N @ LxNxP -> LxP /// N @ KxLxNxP -> KxLxP /// /// KxLxMxN @ NxP -> KxLxMxP /// KxLxMxN @ LxNxP -> KxLxMxP /// Kx1xMxN @ LxNxP -> KxLxMxP /// /// MxN @ KxLxNxP -> KxLxMxP /// LxMxN @ KxLxNxP -> KxLxMxP /// 1xMxN @ KxLxNxP -> KxLxMxP /// /// N @ NxP -> P /// N @ LxNxP -> LxP /// N @ KxLxNxP -> KxLxP /// /// MxN @ N -> M /// LxMxN @ N -> LxM /// KxLxMxN @ N -> KxLxM /// /// MxN @ NxP -> MxP static llvm::APInt sqMANP_matmul_internal( llvm::ArrayRef shape, size_t destroyedDimension, llvm::SmallVector iterPoint, mlir::detail::ElementsAttrRange clearValues, llvm::APInt encryptedOperandNorm) { assert(iterPoint.size() >= shape.size() && "Tensor shape dimensionality is larger than iteration space " "dimensionality"); assert(destroyedDimension < iterPoint.size() && "Destroyed dimension outside of iteration space dimensionality"); size_t currentDimension = iterPoint.size() - shape.size(); if (currentDimension == destroyedDimension) { // the dot product over destroyed dimension will sum products counting down // from the largest index iterPoint[currentDimension] = shape[0] - 1; return sqMANP_matmul_internal(shape.drop_front(1), destroyedDimension, iterPoint, clearValues, encryptedOperandNorm); } if (shape.size() == 0) { // `iterPoint` is defined in all indices, let's // compute the dot product llvm::APInt accumulationNorm = llvm::APInt{1, 0, false}; for (int64_t i = iterPoint[destroyedDimension]; i >= 0; i--) { iterPoint[destroyedDimension] = i; llvm::APInt weight = clearValues[iterPoint]; llvm::APInt weightNorm = APIntWidthExtendSqForConstant(weight); llvm::APInt multiplicationNorm = APIntWidthExtendUMul(encryptedOperandNorm, weightNorm); accumulationNorm = APIntWidthExtendUAdd(multiplicationNorm, accumulationNorm); } return accumulationNorm; } else { // descend into all indices in current dimension llvm::APInt maximumNorm = llvm::APInt{1, 1, false}; for (int64_t i = 0; i < shape[0]; i++) { iterPoint[currentDimension] = i; llvm::APInt accumulationNorm = sqMANP_matmul_internal(shape.drop_front(1), destroyedDimension, iterPoint, clearValues, encryptedOperandNorm); maximumNorm = APIntUMax(maximumNorm, accumulationNorm); } return maximumNorm; } } /// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation /// that is equivalent to an `FHE.mul_eint_int` operation. static llvm::APInt sqMANP_matmul(llvm::APInt encryptedOperandNorm, mlir::RankedTensorType clearOperandType, std::optional> clearVals, unsigned clearOpNum) { assert(clearOperandType.getElementType().isSignlessInteger() && "Only multiplications with signless integers are currently allowed"); llvm::ArrayRef clearOperandShape = clearOperandType.getShape(); uint64_t clearOperandDims = (uint64_t)clearOperandShape.size(); // if the clear operand is LHS (index 0), then the destroyed dimension is its // last (dims-1) if the clear operand is RHS (index 1), then the destroyed // dimension is its second to last (dims-2) assert(clearOpNum <= 1 && "Cannot determine destroyed dimension: operation " "has more than 2 operands"); size_t destroyedDimension = clearOperandDims == 1 ? 0 : clearOperandDims - 1 - clearOpNum; llvm::APInt accNorm = llvm::APInt{1, 0, false}; if (clearVals.has_value()) accNorm = sqMANP_matmul_internal(clearOperandShape, destroyedDimension, llvm::SmallVector( clearOperandShape.size(), 0), clearVals.value(), encryptedOperandNorm); else { llvm::APInt clearOperandNorm = conservativeIntNorm2Sq(clearOperandType.getElementType()); llvm::APInt mulNorm = APIntWidthExtendUMul(encryptedOperandNorm, clearOperandNorm); uint64_t N = clearOperandShape[destroyedDimension]; unsigned int Nbits = ceilLog2(N + 1); mulNorm = APIntWidthExtendUMul(mulNorm, APInt{Nbits, N, false}); accNorm = APIntWidthExtendUAdd(mulNorm, accNorm); } return accNorm; } /// Calculates the squared Minimal Arithmetic Noise Padding of a matmul /// operation static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulEintEintOp op, llvm::ArrayRef operandMANPs) { auto rhsType = ((mlir::Type)op.getRhs().getType()).cast(); llvm::ArrayRef rhsShape = rhsType.getShape(); int64_t rhsDims = (int64_t)rhsShape.size(); assert( operandMANPs.size() == 2 && operandMANPs[0]->getValue().getMANP().has_value() && "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().value(); llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().value(); int64_t N = rhsDims <= 2 ? rhsShape[0] : rhsShape[rhsDims - 2]; // Compute MANP of a single matrix cell x matrix cell multiplication // This is used later to compute the MANP of an entire dot product llvm::APInt tlu = {1, 1, false}; llvm::APInt elemMulNorm = APIntWidthExtendUAdd(tlu, tlu); llvm::APInt accNorm = llvm::APInt{1, 0, false}; // For the total MatMul MANP, take the MANP of a single // column-row dot-product // All such dot-products produce the same MANP, there // is no need to take the maximum over the dot-products for (int64_t i = 0; i < N; i++) { accNorm = APIntWidthExtendUAdd(elemMulNorm, accNorm); } return accNorm; } static llvm::APInt getSqMANP(mlir::tensor::ExtractOp op, llvm::ArrayRef operandMANPs) { assert( operandMANPs[0]->getValue().getMANP().has_value() && "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); return eNorm; } static std::optional getSqMANP(mlir::tensor::FromElementsOp op, llvm::ArrayRef operandMANPs) { auto max = std::max_element(operandMANPs.begin(), operandMANPs.end(), [](const MANPLattice *a, const MANPLattice *b) { return APIntWidthExtendULT( a->getValue().getMANP().value(), b->getValue().getMANP().value()); }); return (*max)->getValue().getMANP().value(); } static std::optional getSqMANP(mlir::tensor::ExtractSliceOp op, llvm::ArrayRef operandMANPs) { assert( operandMANPs[0]->getValue().getMANP().has_value() && "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); return operandMANPs[0]->getValue().getMANP().value(); } static std::optional getSqMANP(mlir::tensor::InsertSliceOp op, llvm::ArrayRef operandMANPs) { assert( operandMANPs.size() >= 2 && operandMANPs[0]->getValue().getMANP().has_value() && operandMANPs[1]->getValue().getMANP().has_value() && "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); return APIntUMax(operandMANPs[0]->getValue().getMANP().value(), operandMANPs[1]->getValue().getMANP().value()); } static std::optional getSqMANP(mlir::tensor::InsertOp op, llvm::ArrayRef operandMANPs) { assert( operandMANPs.size() >= 2 && operandMANPs[0]->getValue().getMANP().has_value() && operandMANPs[1]->getValue().getMANP().has_value() && "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); return APIntUMax(operandMANPs[0]->getValue().getMANP().value(), operandMANPs[1]->getValue().getMANP().value()); } static std::optional getSqMANP(mlir::tensor::CollapseShapeOp op, llvm::ArrayRef operandMANPs) { assert( operandMANPs.size() >= 1 && operandMANPs[0]->getValue().getMANP().has_value() && "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); return operandMANPs[0]->getValue().getMANP().value(); } static std::optional getSqMANP(mlir::tensor::ExpandShapeOp op, llvm::ArrayRef operandMANPs) { assert( operandMANPs.size() >= 1 && operandMANPs[0]->getValue().getMANP().has_value() && "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); return operandMANPs[0]->getValue().getMANP().value(); } static std::optional getSqMANP(mlir::concretelang::FHELinalg::SumOp op, llvm::ArrayRef operandMANPs) { auto inputType = op.getOperand().getType().dyn_cast(); uint64_t numberOfElementsInTheInput = inputType.getNumElements(); if (numberOfElementsInTheInput == 0) { return llvm::APInt{1, 1, false}; } uint64_t numberOfElementsAddedTogetherInEachOutputCell = 1; mlir::ArrayAttr axes = op.getAxes(); if (axes.empty()) { numberOfElementsAddedTogetherInEachOutputCell *= numberOfElementsInTheInput; } else { llvm::ArrayRef shape = inputType.getShape(); for (mlir::Attribute axisAttribute : op.getAxes()) { int64_t axis = axisAttribute.cast().getInt(); numberOfElementsAddedTogetherInEachOutputCell *= shape[axis]; } } unsigned int noiseMultiplierBits = ceilLog2(numberOfElementsAddedTogetherInEachOutputCell + 1); auto noiseMultiplier = llvm::APInt{ noiseMultiplierBits, numberOfElementsAddedTogetherInEachOutputCell, false, }; assert(operandMANPs.size() == 1 && operandMANPs[0]->getValue().getMANP().has_value() && "Missing squared Minimal Arithmetic Noise Padding for encrypted " "operands"); llvm::APInt operandMANP = operandMANPs[0]->getValue().getMANP().value(); return APIntWidthExtendUMul(noiseMultiplier, operandMANP); } static std::optional getSqMANP(mlir::concretelang::FHELinalg::ConcatOp op, llvm::ArrayRef operandMANPs) { llvm::APInt result = llvm::APInt{1, 0, false}; for (const MANPLattice *operandMANP : operandMANPs) { llvm::APInt candidate = operandMANP->getValue().getMANP().value(); if (candidate.getLimitedValue() >= result.getLimitedValue()) { result = candidate; } } return result; } static std::optional getSqMANP(mlir::concretelang::FHELinalg::FancyIndexOp op, llvm::ArrayRef operandMANPs) { return operandMANPs[0]->getValue().getMANP().value(); } static std::optional getSqMANP(mlir::concretelang::FHELinalg::FancyAssignOp op, llvm::ArrayRef operandMANPs) { auto inputManp = operandMANPs[0]->getValue().getMANP().value(); auto valuesManp = operandMANPs[2]->getValue().getMANP().value(); return APIntUMax(inputManp, valuesManp); } static llvm::APInt sqMANP_conv2d(llvm::APInt inputNorm, mlir::RankedTensorType weightTy, std::optional> weightVals) { // Initial value of the accumulator to 0 llvm::APInt accNorm = llvm::APInt{1, 0, false}; // Weight shapes: Filter*Channel*Height*Width uint64_t F = weightTy.getShape()[0]; uint64_t C = weightTy.getShape()[1]; uint64_t H = weightTy.getShape()[2]; uint64_t W = weightTy.getShape()[3]; if (weightVals.has_value()) { // For a constant weight kernel use actual constant to calculate 2-norm // input windows are being multiplied by a kernel and summed up for (uint64_t f = 0; f < F; f++) { llvm::APInt tmpNorm = inputNorm; for (uint64_t c = 0; c < C; c++) { for (uint64_t h = 0; h < H; h++) { for (uint64_t w = 0; w < W; w++) { llvm::APInt cst = weightVals.value()[{f, c, h, w}]; llvm::APInt weightNorm = APIntWidthExtendSqForConstant(cst); llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm); tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm); } } } // Take the max of the 2-norm on the filter accNorm = APIntUMax(accNorm, tmpNorm); } } else { // For a dynamic operand conservatively assume that the value is // the maximum for the integer width llvm::APInt weightNorm = conservativeIntNorm2Sq(weightTy.getElementType()); // For a weight (kernel) of shape tensor, there is C*H*W // FHE.mul_eint_int and FHE.add_eint operations for each elements of the // result int64_t n_mul = C * H * W; llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm); for (int64_t i = 0; i < n_mul; i++) { accNorm = APIntWidthExtendUAdd(mulNorm, accNorm); } } return accNorm; } class MANPAnalysis : public mlir::dataflow::SparseDataFlowAnalysis { public: explicit MANPAnalysis(mlir::DataFlowSolver &solver, bool debug) : mlir::dataflow::SparseDataFlowAnalysis(solver), debug(debug) {} void setToEntryState(MANPLattice *lattice) override { if (isEncryptedFunctionParameter(lattice->getPoint())) { // Set minimal MANP for encrypted function arguments propagateIfChanged(lattice, lattice->join(MANPLatticeValue{ std::optional{llvm::APInt(1, 1)}})); } // In case of block arguments used in the block of a linalg.genric // operation: map the MANP values of the operands into the block arguments else if (lattice->getPoint().isa() || mlir::concretelang::fhe::utils::isEncryptedValue( lattice->getPoint())) { mlir::Block *block = lattice->getPoint().cast().getOwner(); if (block && block->getParentOp() && llvm::isa(block->getParentOp())) { auto genericOp = mlir::dyn_cast(block->getParentOp()); // Get the MANP from the corresponding input/output auto argIndex = lattice->getPoint().cast().getArgNumber(); auto operandRange = genericOp.getInputs(); if (argIndex >= operandRange.size()) { argIndex -= operandRange.size(); operandRange = genericOp.getOutputs(); } auto v = operandRange[argIndex]; auto manp = this->getLatticeElement(v)->getValue().getMANP().value_or( llvm::APInt(1, 1)); propagateIfChanged(lattice, lattice->join(MANPLatticeValue{manp})); } } else { // Everything else is initialized with an unset value propagateIfChanged(lattice, lattice->join(MANPLatticeValue{})); } } std::optional norm2SqEquivFromOp(Operation *op, ArrayRef operands) { std::optional norm2SqEquiv; if (auto cstNoiseOp = llvm::dyn_cast(op)) { if (llvm::isa(op)) { norm2SqEquiv = llvm::APInt{1, 0, false}; } else { norm2SqEquiv = llvm::APInt{1, 1, false}; } } else if (llvm::isa(op) || llvm::isa(op)) { norm2SqEquiv = getNoOpSqMANP(operands); } // FHE and FHELinalg Operators else if (auto unaryEintOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(unaryEintOp, operands); } else if (auto binaryEintIntOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(binaryEintIntOp, operands); } else if (auto binaryIntEintOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(binaryIntEintOp, operands); } else if (auto binaryEintOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(binaryEintOp, operands); } else if (auto dotEintOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(dotEintOp, operands); } else if (auto matmulEintEintOp = llvm::dyn_cast< mlir::concretelang::FHELinalg::MatMulEintEintOp>(op)) { norm2SqEquiv = getSqMANP(matmulEintEintOp, operands); } else if (auto sumOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(sumOp, operands); } else if (auto concatOp = llvm::dyn_cast( op)) { norm2SqEquiv = getSqMANP(concatOp, operands); } else if (auto fromElementOp = llvm::dyn_cast( op)) { if (operands[0]->getValue().getMANP().has_value()) { norm2SqEquiv = operands[0]->getValue().getMANP().value(); } else norm2SqEquiv = llvm::APInt{1, 1, false}; } else if (auto fancyIndexOp = llvm::dyn_cast( op)) { norm2SqEquiv = getSqMANP(fancyIndexOp, operands); } else if (auto fancyAssignOp = llvm::dyn_cast( op)) { norm2SqEquiv = getSqMANP(fancyAssignOp, operands); } // Tensor Operators // ExtractOp else if (auto extractOp = llvm::dyn_cast(op)) { if (extractOp.getResult() .getType() .isa()) { norm2SqEquiv = getSqMANP(extractOp, operands); } else { norm2SqEquiv = {}; } } // ExtractSliceOp else if (auto extractSliceOp = llvm::dyn_cast(op)) { if (extractSliceOp.getResult() .getType() .cast() .getElementType() .isa()) { norm2SqEquiv = getSqMANP(extractSliceOp, operands); } else { norm2SqEquiv = {}; } } // InsertOp else if (auto insertOp = llvm::dyn_cast(op)) { if (insertOp.getResult() .getType() .cast() .getElementType() .isa()) { norm2SqEquiv = getSqMANP(insertOp, operands); } else { norm2SqEquiv = {}; } } // InsertSliceOp else if (auto insertSliceOp = llvm::dyn_cast(op)) { if (insertSliceOp.getResult() .getType() .cast() .getElementType() .isa()) { norm2SqEquiv = getSqMANP(insertSliceOp, operands); } else { norm2SqEquiv = {}; } } // FromElementOp else if (auto fromOp = llvm::dyn_cast(op)) { if (fromOp.getResult() .getType() .cast() .getElementType() .isa()) { norm2SqEquiv = getSqMANP(fromOp, operands); } else { norm2SqEquiv = {}; } } // TensorCollapseShapeOp else if (auto reshapeOp = llvm::dyn_cast(op)) { if (reshapeOp.getResult() .getType() .cast() .getElementType() .isa()) { norm2SqEquiv = getSqMANP(reshapeOp, operands); } else { norm2SqEquiv = {}; } } // TensorExpandShapeOp else if (auto reshapeOp = llvm::dyn_cast(op)) { if (reshapeOp.getResult() .getType() .cast() .getElementType() .isa()) { norm2SqEquiv = getSqMANP(reshapeOp, operands); } else { norm2SqEquiv = {}; } } // Linalg Generic else if (auto linalgGenericOp = llvm::dyn_cast(op)) { norm2SqEquiv = emulateLinalgGenric(linalgGenericOp, operands); } else if (llvm::isa(op)) { norm2SqEquiv = {}; } else if (llvm::isa( *op->getDialect())) { op->emitError("Unsupported operation"); assert(false && "Unsupported operation"); } else { norm2SqEquiv = {}; } return norm2SqEquiv; } void visitOperation(Operation *op, ArrayRef operands, ArrayRef results) override { MANPLattice *latticeRes = results[0]; std::optional norm2SqEquiv = norm2SqEquivFromOp(op, operands); if (norm2SqEquiv.has_value()) { latticeRes->join(MANPLatticeValue{norm2SqEquiv}); op->setAttr("SMANP", mlir::IntegerAttr::get( mlir::IntegerType::get( op->getContext(), norm2SqEquiv.value().getBitWidth(), mlir::IntegerType::SignednessSemantics::Unsigned), norm2SqEquiv.value())); llvm::APInt norm2Equiv = APIntCeilSqrt(norm2SqEquiv.value()); op->setAttr("MANP", mlir::IntegerAttr::get( mlir::IntegerType::get( op->getContext(), norm2Equiv.getBitWidth(), mlir::IntegerType::SignednessSemantics::Unsigned), norm2Equiv)); if (debug) { op->emitRemark("Squared Minimal Arithmetic Noise Padding: ") << APIntToStringValUnsigned(norm2SqEquiv.value()) << "\n"; } } else { latticeRes->join(MANPLatticeValue{}); } } /// Compute the flat index of a tensor given its shape, the current loop /// indices, and the map between loop and tensor indices size_t indexFromLoopRange(mlir::SmallVector loopIndices, mlir::AffineMap map, mlir::ArrayRef shape) { llvm::SmallVector tensorIndices; llvm::SmallVector constantIndices; for (auto i : loopIndices) { constantIndices.push_back( IntegerAttr::get(IntegerType::get(map.getContext(), 64), i)); } assert(map.constantFold(constantIndices, tensorIndices).succeeded()); assert(tensorIndices.size() == shape.size()); int64_t multiplier = 1; size_t index = 0; for (int64_t i = shape.size() - 1; i >= 0; i--) { index += tensorIndices[i].cast().getInt() * multiplier; multiplier *= shape[i]; } return index; } // Compute the MANP value of a linalg.generic operation by emulating its // execution std::optional emulateLinalgGenric(mlir::linalg::GenericOp genericOp, llvm::ArrayRef operandMANPs) { assert(genericOp.getOutputs().size() == 1 && "MANP doesn't support linalg.genric with more than one output"); // We want to use a different mechanism to store MANP values than the // Analysis. We don't want to write anything to the lattice values // controlled by the analysis, but we will use them to read values that we // don't yet have in the emulation mlir::DenseMap valueToManp; auto fetchOrFallbackToAnalysis = [&](mlir::Value value) -> MANPLatticeValue * { auto elem = valueToManp.find(value); if (elem != valueToManp.end()) { return &elem->second; } auto manp = getLatticeElement(value)->getValue(); valueToManp[value] = manp; return &valueToManp[value]; }; auto loopRange = mlir::concretelang::fhe::utils::getLinalgGenericLoopRange(genericOp); auto iterCount = std::accumulate(loopRange.begin(), loopRange.end(), 1, std::multiplies()); llvm::SmallVector strides; auto stride = iterCount; for (size_t i = 0; i < loopRange.size(); i++) { stride /= loopRange[i]; strides.push_back(stride); } // clone the genricOp to replace block arguments with constant values when // needed. The clone op must be destroyed at the end of the function auto genericOpClone = genericOp.clone(); // init block arguments' MANP: map block arguments with op operands for (auto arg : genericOpClone.getBlock()->getArguments()) { auto argIndex = arg.getArgNumber(); auto operandRange = genericOpClone.getInputs(); if (argIndex >= operandRange.size()) { argIndex -= operandRange.size(); operandRange = genericOpClone.getOutputs(); } valueToManp[arg] = getLatticeElement(operandRange[argIndex])->getValue(); } // keep track of the MANP of different elements in the output tensor // (initialized to the initial output MANP value) auto outputArg = genericOpClone.getBlock()->getArguments().back(); auto outputType = genericOpClone.getOutputs().front().getType().cast(); auto outputSize = std::accumulate(outputType.getShape().begin(), outputType.getShape().end(), 1, std::multiplies()); std::vector outputMANPs( outputSize, fetchOrFallbackToAnalysis(outputArg)->getMANP().value()); // indices at a specific iteration llvm::SmallVector indices(loopRange.size(), 0); for (auto i = 0; i < iterCount; i++) { for (size_t iterPos = 0; iterPos < indices.size(); iterPos++) { indices[iterPos] = (i / strides[iterPos]) % loopRange[iterPos]; } // if a linalg genric input is constant, replace the uses of its // respective block argument with a constant value. This avoids the // computation of the MANP to use conservative values. // we also have to replace them back for the next iteration to be able to // update them again with new values mlir::DenseMap toReplaceBack; for (auto arg : genericOpClone.getBlock()->getArguments()) { auto argIndex = arg.getArgNumber(); auto inputs = genericOpClone.getInputs(); // don't consider outputs if (argIndex >= inputs.size()) continue; auto input = inputs[argIndex]; auto definingOp = input.getDefiningOp(); if (definingOp && mlir::isa(definingOp)) { // fetch constant value auto constantOp = mlir::dyn_cast(definingOp); mlir::DenseIntElementsAttr denseAttr = constantOp.getValueAttr().dyn_cast(); auto constantTensor = denseAttr.getValues(); auto constantIndex = indexFromLoopRange( indices, genericOpClone.getIndexingMapsArray()[argIndex], denseAttr.getType().getShape()); APInt constantValue = constantTensor[constantIndex]; // create new constant op with constant value auto opBuilder = mlir::OpBuilder(genericOpClone.getContext()); auto opState = mlir::OperationState( mlir::UnknownLoc::get(genericOpClone.getContext()), arith::ConstantOp::getOperationName()); arith::ConstantOp::build( opBuilder, opState, mlir::IntegerAttr::get(denseAttr.getType().getElementType(), constantValue)); auto newConstantOp = arith::ConstantOp(mlir::Operation::create(opState)); genericOpClone.getBlock()->push_front(newConstantOp); // replace uses of the block argument with the constant value arg.replaceAllUsesWith(newConstantOp.getResult()); toReplaceBack[arg] = newConstantOp; } } // we want to replace the MANP of the block argument corresponding to the // output with the MANP value corresponding to the currently accessed // tensor element size_t outputIndex = indexFromLoopRange( indices, genericOpClone.getIndexingMapsArray()[outputArg.getArgNumber()], outputType.getShape()); valueToManp[outputArg] = MANPLatticeValue(outputMANPs[outputIndex]); genericOpClone.getBody()->walk([&](mlir::Operation *op) { // we update the appropriate element's MANP value using the index of the // currently accessed output element if (auto yieldOp = mlir::dyn_cast(op)) { auto manp = fetchOrFallbackToAnalysis(yieldOp->getOperand(0)); outputMANPs[outputIndex] = manp->getMANP().value(); return; } // compute using the op and operand manp values mlir::SmallVector latticeOperands; for (auto operand : op->getOperands()) { auto lattice = new MANPLattice(operand); lattice->join(*fetchOrFallbackToAnalysis(operand)); latticeOperands.push_back(lattice); } std::optional norm2SqEquiv = norm2SqEquivFromOp(op, latticeOperands); // update the MANP of the result value if (op->getNumResults() > 0) { valueToManp[op->getResult(0).cast()] = MANPLatticeValue(norm2SqEquiv); } // free space of lattice elements for (auto toFree : latticeOperands) { delete toFree; } }); // replace back the uses of block arguments which were replaced by // constant values for (auto replacement : toReplaceBack) { auto blockArg = replacement.first; auto constantOp = replacement.second; auto constantValue = constantOp.getResult(); constantValue.replaceAllUsesWith(blockArg); constantOp->remove(); constantOp->destroy(); } } genericOpClone->destroy(); // final result MANP is the max of output llvm::APInt result = outputMANPs[0]; for (auto manp : outputMANPs) { result = APIntUMax(result, manp); } return result; } private: bool debug; }; } // namespace namespace FHE { llvm::APInt AddEintOp::sqMANP(llvm::APInt a, llvm::APInt b) { return APIntWidthExtendUAdd(a, b); } llvm::APInt SubEintOp::sqMANP(llvm::APInt a, llvm::APInt b) { return APIntWidthExtendUAdd(a, b); } llvm::APInt MulEintIntOp::sqMANP(llvm::APInt a) { return sqMANP_mul_eint_int( a, this->operandIntType(this->getClearOperandNumber()), this->operandMaxConstant(this->getClearOperandNumber())); } llvm::APInt MulEintOp::sqMANP(llvm::APInt a, llvm::APInt b) { return sqMANP_mul_eint(a, b); } llvm::APInt MaxEintOp::sqMANP(llvm::APInt a, llvm::APInt b) { // max(a, b) = max(a - b, 0) + b const llvm::APInt sub = APIntWidthExtendUAdd(a, b); const llvm::APInt tlu = {1, 1, false}; const llvm::APInt add = APIntWidthExtendUAdd(tlu, b); // this is not optimal as it can increase the resulting noise unnecessarily return APIntUMax(add, sub); } llvm::APInt RoundEintOp::sqMANP(llvm::APInt a) { uint64_t inputWidth = this->getOperand().getType().cast().getWidth(); uint64_t outputWidth = this->getResult().getType().cast().getWidth(); uint64_t clearedBits = inputWidth - outputWidth; return a + clearedBits; } } // namespace FHE namespace FHELinalg { llvm::APInt AddEintOp::sqMANP(llvm::APInt a, llvm::APInt b) { return APIntWidthExtendUAdd(a, b); } llvm::APInt SubEintOp::sqMANP(llvm::APInt a, llvm::APInt b) { return APIntWidthExtendUAdd(a, b); } llvm::APInt MulEintIntOp::sqMANP(llvm::APInt a) { return sqMANP_mul_eint_int( a, this->operandIntType(this->getClearOperandNumber()), this->operandMaxConstant(this->getClearOperandNumber())); } llvm::APInt MulEintOp::sqMANP(llvm::APInt a, llvm::APInt b) { return sqMANP_mul_eint(a, b); } llvm::APInt Dot::sqMANP(llvm::APInt a) { unsigned clearOpNum = this->getClearOperandNumber(); auto clearOperandType = this->getOperation() ->getOpOperand(clearOpNum) .get() .getType() .cast(); return sqMANP_matmul(a, clearOperandType, this->opTensorConstant(clearOpNum), clearOpNum); } llvm::APInt MatMulEintIntOp::sqMANP(llvm::APInt a) { unsigned clearOpNum = this->getClearOperandNumber(); auto clearOpType = this->getOperation() ->getOpOperand(clearOpNum) .get() .getType() .cast(); return sqMANP_matmul(a, clearOpType, this->opTensorConstant(clearOpNum), clearOpNum); } llvm::APInt MatMulIntEintOp::sqMANP(llvm::APInt a) { unsigned clearOpNum = this->getClearOperandNumber(); assert(clearOpNum <= 1 && "Operation has more than 2 operands"); auto clearOpType = this->getOperation() ->getOpOperand(clearOpNum) .get() .getType() .cast(); return sqMANP_matmul(a, clearOpType, this->opTensorConstant(clearOpNum), clearOpNum); } llvm::APInt Conv2dOp::sqMANP(llvm::APInt a) { unsigned clearOpNum = this->getClearOperandNumber(); auto clearOpType = this->getOperation() ->getOpOperand(clearOpNum) .get() .getType() .cast(); return sqMANP_conv2d(a, clearOpType, this->opTensorConstant(clearOpNum)); } llvm::APInt Maxpool2dOp::sqMANP(llvm::APInt a) { // maximum between two value is calculated using // - max(x - y, 0) + y // max is calculated with a TLU so MANP is {1, 1, false} // y on the other hand comes from the input or from the previous result // in the current implementation, it's the input // so the resulting MANP is `{1, 1, false} + MANP input` const llvm::APInt tlu = {1, 1, false}; const llvm::APInt forResult = APIntWidthExtendUAdd(tlu, a); const llvm::APInt forIntermediate = APIntWidthExtendUAdd(forResult, a); return APIntUMax(forIntermediate, forResult); } llvm::APInt RoundOp::sqMANP(llvm::APInt a) { const uint64_t inputWidth = this->getOperand() .getType() .cast() .getElementType() .cast() .getWidth(); const uint64_t outputWidth = this->getResult() .getType() .cast() .getElementType() .cast() .getWidth(); const uint64_t clearedBits = inputWidth - outputWidth; return a + clearedBits; } } // namespace FHELinalg namespace { /// For documentation see MANP.td struct MANPPass : public MANPBase { void runOnOperation() override { mlir::func::FuncOp func = getOperation(); mlir::DataFlowSolver solver; solver.load(); solver.load(debug); if (failed(solver.initializeAndRun(func))) return signalPassFailure(); } MANPPass() = delete; MANPPass(bool debug) : debug(debug){}; protected: bool debug; }; } // end anonymous namespace /// Create an instance of the Minimal Arithmetic Noise Padding analysis /// pass. If `debug` is true, for each operation, the pass emits a /// remark containing the squared Minimal Arithmetic Noise Padding of /// the equivalent dot operation. std::unique_ptr createMANPPass(bool debug) { return std::make_unique(debug); } namespace { /// For documentation see MANP.td struct MaxMANPPass : public MaxMANPBase { void runOnOperation() override { mlir::func::FuncOp func = getOperation(); func.walk( [&](mlir::Operation *childOp) { this->processOperation(childOp); }); } MaxMANPPass() = delete; MaxMANPPass(std::function updateMax) : updateMax(updateMax){}; protected: void processOperation(mlir::Operation *op) { // Process all function arguments and use the default value of 1 // for MANP and the declarend precision if (mlir::func::FuncOp func = llvm::dyn_cast_or_null(op)) { for (mlir::BlockArgument blockArg : func.getBody().getArguments()) { if (isEncryptedFunctionParameter(blockArg)) { unsigned int width = fhe::utils::getEintPrecision(blockArg); this->updateMax(1, width); } } } // Process all results using MANP attribute from MANP pas for (mlir::OpResult res : op->getResults()) { mlir::concretelang::FHE::FheIntegerInterface eTy = res.getType() .dyn_cast_or_null(); if (eTy == nullptr) { auto tensorTy = res.getType().dyn_cast_or_null(); if (tensorTy != nullptr) { eTy = tensorTy.getElementType() .dyn_cast_or_null< mlir::concretelang::FHE::FheIntegerInterface>(); } } if (eTy) { mlir::IntegerAttr MANP = op->getAttrOfType("MANP"); if (!MANP) { op->emitError("2-Norm has not been computed"); this->signalPassFailure(); return; } auto manp = MANP.getValue(); if (!manp.isIntN(64)) { op->emitError("2-Norm cannot be reprensented on 64bits"); this->signalPassFailure(); return; } this->updateMax(manp.getZExtValue(), eTy.getWidth()); } } } std::function updateMax; }; } // end anonymous namespace std::unique_ptr createMaxMANPPass(std::function updateMax) { return std::make_unique(updateMax); } } // namespace concretelang } // namespace mlir