From ed762942c15c2b3204d06549a0c01fe7e7486ba6 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Fri, 3 Sep 2021 12:01:56 +0200 Subject: [PATCH] feat(compiler): Add pass for Minimal Arithmetic Noise Padding This pass calculates the squared Minimal Arithmetic Noise Padding (MANP) for each operation of a function and stores the result in an integer attribute named "sqMANP". This metric is identical to the squared 2-norm of the constant vector of an equivalent dot product between a vector of encrypted integers resulting directly from an encryption and a vector of plaintext constants. The pass supports the following operations: - HLFHE.dot_eint_int - HLFHE.zero - HLFHE.add_eint_int - HLFHE.add_eint - HLFHE.sub_int_eint - HLFHE.mul_eint_int - HLFHE.apply_lookup_table If any other operation is encountered, the pass conservatively fails. The pass further makes the optimistic assumption that all values passed to a function are either the direct result of an encryption of a noise-refreshing operation. --- .../Dialect/HLFHE/Analysis/CMakeLists.txt | 7 + .../zamalang/Dialect/HLFHE/Analysis/MANP.h | 12 + .../zamalang/Dialect/HLFHE/Analysis/MANP.td | 95 ++++ .../zamalang/Dialect/HLFHE/CMakeLists.txt | 3 +- .../lib/Dialect/HLFHE/Analysis/CMakeLists.txt | 15 + compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp | 485 ++++++++++++++++++ compiler/lib/Dialect/HLFHE/CMakeLists.txt | 3 +- compiler/src/CMakeLists.txt | 1 + compiler/src/main.cpp | 25 +- 9 files changed, 643 insertions(+), 3 deletions(-) create mode 100644 compiler/include/zamalang/Dialect/HLFHE/Analysis/CMakeLists.txt create mode 100644 compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.h create mode 100644 compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.td create mode 100644 compiler/lib/Dialect/HLFHE/Analysis/CMakeLists.txt create mode 100644 compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp diff --git a/compiler/include/zamalang/Dialect/HLFHE/Analysis/CMakeLists.txt b/compiler/include/zamalang/Dialect/HLFHE/Analysis/CMakeLists.txt new file mode 100644 index 000000000..71fc391bf --- /dev/null +++ b/compiler/include/zamalang/Dialect/HLFHE/Analysis/CMakeLists.txt @@ -0,0 +1,7 @@ +set(LLVM_TARGET_DEFINITIONS MANP.td) +mlir_tablegen(MANP.h.inc -gen-pass-decls -name Analysis) +mlir_tablegen(MANP.capi.h.inc -gen-pass-capi-header --prefix Analysis) +mlir_tablegen(MANP.capi.cpp.inc -gen-pass-capi-impl --prefix Analysis) +add_public_tablegen_target(MANPPassIncGen) + +add_mlir_doc(MANP GeneralPasses ./ -gen-pass-doc) diff --git a/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.h b/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.h new file mode 100644 index 000000000..32c364948 --- /dev/null +++ b/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.h @@ -0,0 +1,12 @@ +#ifndef ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP_H +#define ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP_H + +#include + +namespace mlir { +namespace zamalang { +std::unique_ptr createMANPPass(bool debug = false); +} // namespace zamalang +} // namespace mlir + +#endif diff --git a/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.td b/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.td new file mode 100644 index 000000000..cda14a55d --- /dev/null +++ b/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.td @@ -0,0 +1,95 @@ +#ifndef ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP +#define ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP + +include "mlir/Pass/PassBase.td" + +def MANP : FunctionPass<"MANP"> { + let summary = "HLFHE Minimal Arithmetic Noise Padding Pass"; + let description = [{ + This pass calculates the Minimal Arithmetic Noise Padding + (MANP) for each operation of a function and stores the result in an + integer attribute named "MANP". This metric is identical to the + ceiled 2-norm of the constant vector of an equivalent dot product + between a vector of encrypted integers resulting directly from an + encryption and a vector of plaintext constants. + + The pass supports the following operations: + + - HLFHE.dot_eint_int + - HLFHE.zero + - HLFHE.add_eint_int + - HLFHE.add_eint + - HLFHE.sub_int_eint + - HLFHE.mul_eint_int + - HLFHE.apply_lookup_table + + If any other operation is encountered, the pass conservatively + fails. The pass further makes the optimistic assumption that all + values passed to a function are either the direct result of an + encryption of a noise-refreshing operation. + + Conceptually, the pass is equivalent to the three steps below: + + 1. Replace all arithmetic operations with an equivalent dot + operation + + 2. Merge resulting dot operations into a single, equivalent + dot operation + + 3. Calculate the 2-norm of the vector of plaintext constants + of the dot operation + + with the following replacement rules: + + - Function argument a -> HLFHE.dot_eint_int([a], [1]) + - HLFHE.apply_lookup_table -> HLFHE.dot_eint_int([LUT result], [1]) + - HLFHE.zero() -> HLFHE.dot_eint_int([encrypted 0], [1]) + - HLFHE.add_eint_int(e, c) -> HLFHE.dot_eint_int([e, 1], [1, c]) + - HLFHE.add_eint(e0, e1) -> HLFHE.dot_eint_int([e0, e1], [1, 1]) + - HLFHE.sub_int_eint(c, e) -> HLFHE.dot_eint_int([e, c], [1, -1]) + - HLFHE.mul_eint_int(e, c) -> HLFHE.dot_eint_int([e], [c]) + + Dependent dot operations, e.g., + + a = HLFHE.dot_eint_int([a0, a1, ...], [c0, c1, ...]) + b = HLFHE.dot_eint_int([b0, b1, ...], [d0, d1, ...]) + x = HLFHE.dot_eint_int([a, b, ...], [f0, f1, ...]) + + are merged as follows: + + x = HLFHE.dot_eint_int([a0, a1, ..., b0, b1, ...], + [f0*c0, f0*c1, ..., f1*d0, f1*d1, ...]) + + However, the implementation does not explicitly create the + equivalent dot operations, but only accumulates the squared 2-norm + of the constant vector of the equivalent dot operation along the + edges of the data-flow graph composed by the operations in order to + calculate the final 2-norm for the final single dot operation above. + + For the example above, this means that the pass calculates the + squared 2-norm of x, sqN(x) as: + + sqN(a) = c0*c0 + c1*c1 + ... + sqN(b) = d0*d0 + d1*d1 + ... + sqN(x) = f0*f0*c0*c0 + f0*f0*c1*c1 + ... + f1*f1*d0*d0 + f1*f1*d1*d1 + ... + = f0*f0*sqN(a) + f1*f1*sqN(b) + + This leads to the following rules to calculate the squared 2-norm + for the supported operations: + + - Function argument -> 1 + - HLFHE.apply_lookup_table -> 1 + - HLFHE.zero() -> 1 + - HLFHE.dot_eint_int([e0, e1, ...], [c0, c1, ...]) -> + c0*c0*sqN(e0) + c1*c1*sqN(e1) + ... + - HLFHE.add_eint_int(e, c) -> 1*1*sqN(e) + c*c*1*1 = sqN(e) + c*c + - HLFHE.add_eint(e0, e1) -> 1*1*sqN(e0) + 1*1*sqN(e2) = sqN(e1) + sqN(e2) + - HLFHE.sub_int_eint(c, e) -> 1*1*sqN(e) + c*c*(-1)*(-1) = sqN(e) + c*c + - HLFHE.mul_eint_int(e, c) -> c*c*sqN(e) + + The final, non-squared 2-norm of an operation is the square root of the + squared value rounded to the next highest integer. + }]; +} + +#endif diff --git a/compiler/include/zamalang/Dialect/HLFHE/CMakeLists.txt b/compiler/include/zamalang/Dialect/HLFHE/CMakeLists.txt index 7d59dce8e..4f7494893 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/CMakeLists.txt +++ b/compiler/include/zamalang/Dialect/HLFHE/CMakeLists.txt @@ -1 +1,2 @@ -add_subdirectory(IR) \ No newline at end of file +add_subdirectory(Analysis) +add_subdirectory(IR) diff --git a/compiler/lib/Dialect/HLFHE/Analysis/CMakeLists.txt b/compiler/lib/Dialect/HLFHE/Analysis/CMakeLists.txt new file mode 100644 index 000000000..570db4b9f --- /dev/null +++ b/compiler/lib/Dialect/HLFHE/Analysis/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_library(HLFHEDialectAnalysis + MANP.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/HLFHE + + DEPENDS + HLFHEDialect + MANPPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + HLFHEDialect) + +target_link_libraries(HLFHEDialectAnalysis PUBLIC MLIRIR) diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp new file mode 100644 index 000000000..c454a054f --- /dev/null +++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp @@ -0,0 +1,485 @@ +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GEN_PASS_CLASSES +#include + +namespace mlir { +namespace zamalang { +namespace { +// The `MANPLatticeValue` represents the squared Minimal Arithmetic +// Noise Padding for an operation using the squared 2-norm of an +// equivalent dot operation. This can either be an actual value if the +// values for its predecessors have been calculated beforehand or an +// unknown value otherwise. +struct MANPLatticeValue { + MANPLatticeValue(llvm::Optional manp = {}) : manp(manp) {} + + static MANPLatticeValue getPessimisticValueState(mlir::MLIRContext *context) { + return MANPLatticeValue(); + } + + static MANPLatticeValue getPessimisticValueState(mlir::Value value) { + // Function arguments are assumed to require a Minimal Arithmetic + // Noise Padding with a 2-norm of 1. + // + // TODO: Provide a mechanism to propagate Minimal Arithmetic Noise + // Padding across function calls. + if (value.isa() && + (value.getType().isa() || + (value.getType().isa() && + value.getType() + .cast() + .getElementType() + .isa()))) { + return MANPLatticeValue(llvm::APInt{1, 1, false}); + } else { + // All other operations have an unknown Minimal Arithmetic Noise + // Padding until an value for all predecessors has been + // calculated. + return MANPLatticeValue(); + } + } + + bool operator==(const MANPLatticeValue &rhs) const { + return this->manp == rhs.manp; + } + + // Required by `mlir::LatticeElement::join()`, but should never be + // invoked, as `MANPAnalysis::visitOperation()` takes care of + // combining the squared Minimal Arithmetic Noise Padding of + // operands into the Minimal Arithmetic Noise Padding of the result. + static MANPLatticeValue join(const MANPLatticeValue &lhs, + const MANPLatticeValue &rhs) { + assert(false && "Minimal Arithmetic Noise Padding values can only be " + "combined sensibly when the combining operation is known"); + return MANPLatticeValue{}; + } + + llvm::Optional getMANP() { return manp; } + +protected: + llvm::Optional manp; +}; + +// Checks if `lhs` is less than `rhs`, where both values are assumed +// to be positive. The bit width of the smaller `APInt` is extended +// before comparison via `APInt::ult`. +static bool APIntWidthExtendULT(const llvm::APInt &lhs, + const llvm::APInt &rhs) { + if (lhs.getBitWidth() < rhs.getBitWidth()) + return lhs.zext(rhs.getBitWidth()).ult(rhs); + else if (lhs.getBitWidth() > rhs.getBitWidth()) + return lhs.ult(rhs.zext(lhs.getBitWidth())); + else + return lhs.ult(rhs); +} + +// Adds two `APInt` values, where both values are assumed to be +// positive. The bit width of the operands is extended in order to +// guarantee that the sum fits into the resulting `APInt`. +static llvm::APInt APIntWidthExtendUAdd(const llvm::APInt &lhs, + const llvm::APInt &rhs) { + unsigned maxBits = std::max(lhs.getBitWidth(), rhs.getBitWidth()); + + // Make sure the required number of bits can be represented by the + // `unsigned` argument of `zext`. + assert(std::numeric_limits::max() - maxBits > 1); + + unsigned targetWidth = maxBits + 1; + return lhs.zext(targetWidth) + rhs.zext(targetWidth); +} + +// Multiplies two `APInt` values, where both values are assumed to be +// positive. The bit width of the operands is extended in order to +// guarantee that the product fits into the resulting `APInt`. +static llvm::APInt APIntWidthExtendUMul(const llvm::APInt &lhs, + const llvm::APInt &rhs) { + // Make sure the required number of bits can be represented by the + // `unsigned` argument of `zext`. + assert(std::numeric_limits::max() - + std::max(lhs.getBitWidth(), rhs.getBitWidth()) > + std::min(lhs.getBitWidth(), rhs.getBitWidth()) && + "Required number of bits cannot be represented with an APInt"); + + unsigned targetWidth = lhs.getBitWidth() + rhs.getBitWidth(); + return lhs.zext(targetWidth) * rhs.zext(targetWidth); +} + +// Calculates the square of `i`. The bit width `i` is extended in +// order to guarantee that the product fits into the resulting +// `APInt`. +static llvm::APInt APIntWidthExtendUSq(const llvm::APInt &i) { + // Make sure the required number of bits can be represented by the + // `unsigned` argument of `zext`. + assert(i.getBitWidth() < std::numeric_limits::max() / 2 && + "Required number of bits cannot be represented with an APInt"); + + llvm::APInt ie = i.zext(2 * i.getBitWidth()); + + return ie * ie; +} + +// Calculates the square root of `i` and rounds it to the next highest +// integer value (i.e., the square of the result is guaranteed to be +// greater or equal to `i`). +static llvm::APInt APIntCeilSqrt(const llvm::APInt &i) { + llvm::APInt res = i.sqrt(); + llvm::APInt resSq = APIntWidthExtendUSq(res); + + if (APIntWidthExtendULT(resSq, i)) + return APIntWidthExtendUAdd(res, llvm::APInt{1, 1, false}); + else + return res; +} + +// Returns a string representation of `i` assuming that `i` is an +// unsigned value. +static std::string APIntToStringValUnsigned(const llvm::APInt &i) { + llvm::SmallString<32> s; + i.toStringUnsigned(s); + return std::string(s.c_str()); +} + +// Calculates the square of the 2-norm of a tensor initialized with a +// dense matrix of constant, signless integers. Aborts if the value +// type or initialization of of `cstOp` is incorrect. +static llvm::APInt denseCstTensorNorm2Sq(mlir::ConstantOp cstOp) { + mlir::DenseIntElementsAttr denseVals = + cstOp->getAttrOfType("value"); + + assert(denseVals && cstOp.getType().isa() && + "Constant must be a tensor initialized with `dense`"); + + mlir::TensorType tensorType = cstOp.getType().cast(); + + assert(tensorType.getElementType().isSignlessInteger() && + "Can only handle tensors with signless integer elements"); + + mlir::IntegerType elementType = + tensorType.getElementType().cast(); + + llvm::APInt accu{1, 0, false}; + + for (llvm::APInt val : denseVals.getIntValues()) { + llvm::APInt valSq = APIntWidthExtendUSq(val); + accu = APIntWidthExtendUAdd(accu, valSq); + } + + return accu; +} + +// Calculates (T)ceil(log2f(v)) +// TODO: Replace with some fancy bit twiddling hack +template static T ceilLog2(const T v) { + T tmp = v; + T log2 = 0; + + while (tmp >>= 1) + log2++; + + // If more than MSB set, round to next highest power of 2 + if (v & ~((T)1 << log2)) + log2 += 1; + + return log2; +} + +// Calculates the square of the 2-norm of a 1D tensor of signless +// integers by conservatively assuming that the dynamic values are the +// maximum for the integer width. Aborts if the tensor type `tTy` is +// incorrect. +static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy) { + assert(tTy && tTy.getElementType().isSignlessInteger() && + tTy.hasStaticShape() && tTy.getRank() == 1 && + "Plaintext operand must be a statically shaped 1D tensor of integers"); + + // Make sure the log2 of the number of elements fits into an + // unsigned + assert(std::numeric_limits::max() > 8 * sizeof(uint64_t)); + + unsigned elWidth = tTy.getElementTypeBitWidth(); + + llvm::APInt maxVal = APInt::getMaxValue(elWidth); + llvm::APInt maxValSq = APIntWidthExtendUSq(maxVal); + + // Calculate number of bits for APInt to store number of elements + uint64_t nElts = (uint64_t)tTy.getNumElements(); + assert(std::numeric_limits::max() - nElts > 1); + unsigned nEltsBits = (unsigned)ceilLog2(nElts + 1); + + llvm::APInt nEltsAP{nEltsBits, nElts, false}; + + return APIntWidthExtendUMul(maxValSq, nEltsAP); +} + +// Calculates the squared Minimal Arithmetic Noise Padding of an +// `HLFHE.dot_eint_int` operation. +static llvm::APInt getSqMANP( + mlir::zamalang::HLFHE::Dot op, + llvm::ArrayRef *> operandMANPs) { + assert(op->getOpOperand(0).get().isa() && + "Only dot operations with tensors that are function arguments are " + "currently supported"); + + mlir::ConstantOp cstOp = llvm::dyn_cast_or_null( + op->getOpOperand(1).get().getDefiningOp()); + + if (cstOp) { + // Dot product between a vector of encrypted integers and a vector + // of plaintext constants -> return 2-norm of constant vector + return denseCstTensorNorm2Sq(cstOp); + } else { + // Dot product between a vector of encrypted integers and a vector + // of dynamic plaintext values -> conservatively assume that all + // the values are the maximum possible value for the integer's + // width + mlir::TensorType tTy = op->getOpOperand(1) + .get() + .getType() + .dyn_cast_or_null(); + + return denseDynTensorNorm2Sq(tTy); + } +} + +// Returns the squared 2-norm for a dynamic integer by conservatively +// assuming that the integer's value is the maximum for the integer +// width. +static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) { + assert(t.isSignlessInteger() && "Type must be a signless integer type"); + assert(std::numeric_limits::max() - t.getIntOrFloatBitWidth() > 1); + + llvm::APInt maxVal{t.getIntOrFloatBitWidth() + 1, 1, false}; + maxVal <<= t.getIntOrFloatBitWidth(); + return APIntWidthExtendUSq(maxVal); +} + +// Calculates the squared Minimal Arithmetic Noise Padding of an +// `HLFHE.add_eint_int` operation. +static llvm::APInt getSqMANP( + mlir::zamalang::HLFHE::AddEintIntOp op, + llvm::ArrayRef *> operandMANPs) { + mlir::Type iTy = op->getOpOperand(1).get().getType(); + + assert(iTy.isSignlessInteger() && + "Only additions with signless integers are currently allowed"); + + assert( + operandMANPs.size() == 2 && + operandMANPs[0]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + mlir::ConstantOp cstOp = llvm::dyn_cast_or_null( + op->getOpOperand(1).get().getDefiningOp()); + + llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue(); + llvm::APInt sqNorm; + + if (cstOp) { + // For a constant operand use actual constant to calculate 2-norm + mlir::IntegerAttr attr = cstOp->getAttrOfType("value"); + sqNorm = APIntWidthExtendUSq(attr.getValue()); + } else { + // For a dynamic operand conservatively assume that the value is + // the maximum for the integer width + sqNorm = conservativeIntNorm2Sq(iTy); + } + + return APIntWidthExtendUAdd(sqNorm, eNorm); +} + +// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation +// that is equivalent to an `HLFHE.add_eint` operation. +static llvm::APInt getSqMANP( + mlir::zamalang::HLFHE::AddEintOp op, + llvm::ArrayRef *> operandMANPs) { + assert(operandMANPs.size() == 2 && + operandMANPs[0]->getValue().getMANP().hasValue() && + operandMANPs[1]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted " + "operands"); + + llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue(); + llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue(); + + return APIntWidthExtendUAdd(a, b); +} + +// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation +// that is equivalent to an `HLFHE.sub_int_eint` operation. +static llvm::APInt getSqMANP( + mlir::zamalang::HLFHE::SubIntEintOp op, + llvm::ArrayRef *> operandMANPs) { + mlir::Type iTy = op->getOpOperand(0).get().getType(); + + assert(iTy.isSignlessInteger() && + "Only subtractions with signless integers are currently allowed"); + + assert( + operandMANPs.size() == 2 && + operandMANPs[1]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().getValue(); + llvm::APInt sqNorm; + + mlir::ConstantOp cstOp = llvm::dyn_cast_or_null( + op->getOpOperand(0).get().getDefiningOp()); + + if (cstOp) { + // For constant plaintext operands simply use the constant value + mlir::IntegerAttr attr = cstOp->getAttrOfType("value"); + sqNorm = APIntWidthExtendUSq(attr.getValue()); + } else { + // For dynamic plaintext operands conservatively assume that the integer has + // its maximum possible value + sqNorm = conservativeIntNorm2Sq(iTy); + } + return APIntWidthExtendUAdd(sqNorm, eNorm); +} + +// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation +// that is equivalent to an `HLFHE.mul_eint_int` operation. +static llvm::APInt getSqMANP( + mlir::zamalang::HLFHE::MulEintIntOp op, + llvm::ArrayRef *> operandMANPs) { + mlir::Type iTy = op->getOpOperand(1).get().getType(); + + assert(iTy.isSignlessInteger() && + "Only multiplications with signless integers are currently allowed"); + + assert( + operandMANPs.size() == 2 && + operandMANPs[0]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + mlir::ConstantOp cstOp = llvm::dyn_cast_or_null( + op->getOpOperand(1).get().getDefiningOp()); + + llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue(); + llvm::APInt sqNorm; + + if (cstOp) { + // For a constant operand use actual constant to calculate 2-norm + mlir::IntegerAttr attr = cstOp->getAttrOfType("value"); + sqNorm = APIntWidthExtendUSq(attr.getValue()); + } else { + // For a dynamic operand conservatively assume that the value is + // the maximum for the integer width + sqNorm = conservativeIntNorm2Sq(iTy); + } + + return APIntWidthExtendUMul(sqNorm, eNorm); +} + +struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { + using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; + MANPAnalysis(mlir::MLIRContext *ctx, bool debug) + : debug(debug), mlir::ForwardDataFlowAnalysis(ctx) {} + + ~MANPAnalysis() override = default; + + mlir::ChangeResult visitOperation( + mlir::Operation *op, + llvm::ArrayRef *> operands) final { + mlir::LatticeElement &latticeRes = + getLatticeElement(op->getResult(0)); + bool isDummy = false; + llvm::APInt norm2SqEquiv; + + if (auto dotOp = llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(dotOp, operands); + } else if (auto addEintIntOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(addEintIntOp, operands); + } else if (auto addEintOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(addEintOp, operands); + } else if (auto subIntEintOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(subIntEintOp, operands); + } else if (auto mulEintIntOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(mulEintIntOp, operands); + } else if (llvm::isa(op) || + llvm::isa(op)) { + norm2SqEquiv = llvm::APInt{1, 1, false}; + } else if (llvm::isa(op)) { + isDummy = true; + } else if (llvm::isa( + *op->getDialect())) { + op->emitError("Unsupported operation"); + assert(false && "Unsupported operation"); + } else { + isDummy = true; + } + + if (!isDummy) { + latticeRes.join(MANPLatticeValue{norm2SqEquiv}); + latticeRes.markOptimisticFixpoint(); + + llvm::APInt norm2Equiv = APIntCeilSqrt(norm2SqEquiv); + + op->setAttr("MANP", + mlir::IntegerAttr::get( + mlir::IntegerType::get( + op->getContext(), norm2Equiv.getBitWidth(), + mlir::IntegerType::SignednessSemantics::Unsigned), + norm2Equiv)); + + if (debug) { + op->emitRemark("Squared Minimal Arithmetic Noise Padding: ") + << APIntToStringValUnsigned(norm2SqEquiv) << "\n"; + } + } else { + latticeRes.join(MANPLatticeValue{}); + } + + return mlir::ChangeResult::Change; + } + +private: + bool debug; +}; +} // namespace + +namespace { +// For documentation see MANP.td +struct MANPPass : public MANPBase { + void runOnFunction() override { + mlir::FuncOp func = getFunction(); + + MANPAnalysis analysis(func->getContext(), debug); + analysis.run(func); + } + MANPPass() = delete; + MANPPass(bool debug) : debug(debug){}; + +protected: + bool debug; +}; +} // end anonymous namespace + +// Create an instance of the Minimal Arithmetic Noise Padding analysis +// pass. If `debug` is true, for each operation, the pass emits a +// remark containing the squared Minimal Arithmetic Noise Padding of +// the equivalent dot operation. +std::unique_ptr createMANPPass(bool debug) { + return std::make_unique(debug); +} +} // namespace zamalang +} // namespace mlir diff --git a/compiler/lib/Dialect/HLFHE/CMakeLists.txt b/compiler/lib/Dialect/HLFHE/CMakeLists.txt index 7d59dce8e..4f7494893 100644 --- a/compiler/lib/Dialect/HLFHE/CMakeLists.txt +++ b/compiler/lib/Dialect/HLFHE/CMakeLists.txt @@ -1 +1,2 @@ -add_subdirectory(IR) \ No newline at end of file +add_subdirectory(Analysis) +add_subdirectory(IR) diff --git a/compiler/src/CMakeLists.txt b/compiler/src/CMakeLists.txt index 1a29b8b3c..0aed33462 100644 --- a/compiler/src/CMakeLists.txt +++ b/compiler/src/CMakeLists.txt @@ -14,6 +14,7 @@ target_link_libraries(zamacompiler LowLFHEDialect MidLFHEDialect HLFHEDialect + HLFHEDialectAnalysis MLIRIR MLIRLLVMIR diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 517f128ff..ac872e567 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/BuiltinOps.h" #include "zamalang/Conversion/Passes.h" #include "zamalang/Conversion/Utils/GlobalFHEContext.h" +#include "zamalang/Dialect/HLFHE/Analysis/MANP.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" #include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" @@ -30,6 +31,7 @@ enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM }; enum Action { ROUND_TRIP, + DEBUG_MANP, DUMP_MIDLFHE, DUMP_LOWLFHE, DUMP_STD, @@ -84,6 +86,9 @@ static llvm::cl::opt action( llvm::cl::values( clEnumValN(Action::ROUND_TRIP, "roundtrip", "Parse input module and regenerate textual representation")), + llvm::cl::values(clEnumValN( + Action::DEBUG_MANP, "debug-manp", + "Minimal Arithmetic Noise Padding for each HLFHE operation")), llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe", "Lower to MidLFHE and dump result")), llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe", @@ -253,10 +258,28 @@ mlir::LogicalResult processInputBuffer( // a fallthrough mechanism to the next stage. Actions act as exit // points from the pipeline. switch (entryDialect) { - case EntryDialect::HLFHE: + case EntryDialect::HLFHE: { + bool debugMANP = (action == Action::DEBUG_MANP); + + mlir::LogicalResult manpRes = + mlir::zamalang::invokeMANPPass(module, debugMANP); + + if (action == Action::DEBUG_MANP) { + if (manpRes.failed()) { + mlir::zamalang::log_error() + << "Could not calculate Minimal Arithmetic Noise Padding"; + + if (!verifyDiagnostics) + return mlir::failure(); + } else { + return mlir::success(); + } + } + if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose) .failed()) return mlir::failure(); + } // fallthrough case EntryDialect::MIDLFHE: