feat(compiler): Add pass for Minimal Arithmetic Noise Padding

This pass calculates the squared Minimal Arithmetic Noise Padding
(MANP) for each operation of a function and stores the result in an
integer attribute named "sqMANP". This metric is identical to the
squared 2-norm of the constant vector of an equivalent dot product
between a vector of encrypted integers resulting directly from an
encryption and a vector of plaintext constants.

The pass supports the following operations:

 - HLFHE.dot_eint_int
 - HLFHE.zero
 - HLFHE.add_eint_int
 - HLFHE.add_eint
 - HLFHE.sub_int_eint
 - HLFHE.mul_eint_int
 - HLFHE.apply_lookup_table

If any other operation is encountered, the pass conservatively
fails. The pass further makes the optimistic assumption that all
values passed to a function are either the direct result of an
encryption of a noise-refreshing operation.
This commit is contained in:
Andi Drebes
2021-09-03 12:01:56 +02:00
committed by Quentin Bourgerie
parent 30374ebb2c
commit ed762942c1
9 changed files with 643 additions and 3 deletions

View File

@@ -0,0 +1,7 @@
set(LLVM_TARGET_DEFINITIONS MANP.td)
mlir_tablegen(MANP.h.inc -gen-pass-decls -name Analysis)
mlir_tablegen(MANP.capi.h.inc -gen-pass-capi-header --prefix Analysis)
mlir_tablegen(MANP.capi.cpp.inc -gen-pass-capi-impl --prefix Analysis)
add_public_tablegen_target(MANPPassIncGen)
add_mlir_doc(MANP GeneralPasses ./ -gen-pass-doc)

View File

@@ -0,0 +1,12 @@
#ifndef ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP_H
#define ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP_H
#include <mlir/Pass/Pass.h>
namespace mlir {
namespace zamalang {
std::unique_ptr<mlir::Pass> createMANPPass(bool debug = false);
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -0,0 +1,95 @@
#ifndef ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP
#define ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP
include "mlir/Pass/PassBase.td"
def MANP : FunctionPass<"MANP"> {
let summary = "HLFHE Minimal Arithmetic Noise Padding Pass";
let description = [{
This pass calculates the Minimal Arithmetic Noise Padding
(MANP) for each operation of a function and stores the result in an
integer attribute named "MANP". This metric is identical to the
ceiled 2-norm of the constant vector of an equivalent dot product
between a vector of encrypted integers resulting directly from an
encryption and a vector of plaintext constants.
The pass supports the following operations:
- HLFHE.dot_eint_int
- HLFHE.zero
- HLFHE.add_eint_int
- HLFHE.add_eint
- HLFHE.sub_int_eint
- HLFHE.mul_eint_int
- HLFHE.apply_lookup_table
If any other operation is encountered, the pass conservatively
fails. The pass further makes the optimistic assumption that all
values passed to a function are either the direct result of an
encryption of a noise-refreshing operation.
Conceptually, the pass is equivalent to the three steps below:
1. Replace all arithmetic operations with an equivalent dot
operation
2. Merge resulting dot operations into a single, equivalent
dot operation
3. Calculate the 2-norm of the vector of plaintext constants
of the dot operation
with the following replacement rules:
- Function argument a -> HLFHE.dot_eint_int([a], [1])
- HLFHE.apply_lookup_table -> HLFHE.dot_eint_int([LUT result], [1])
- HLFHE.zero() -> HLFHE.dot_eint_int([encrypted 0], [1])
- HLFHE.add_eint_int(e, c) -> HLFHE.dot_eint_int([e, 1], [1, c])
- HLFHE.add_eint(e0, e1) -> HLFHE.dot_eint_int([e0, e1], [1, 1])
- HLFHE.sub_int_eint(c, e) -> HLFHE.dot_eint_int([e, c], [1, -1])
- HLFHE.mul_eint_int(e, c) -> HLFHE.dot_eint_int([e], [c])
Dependent dot operations, e.g.,
a = HLFHE.dot_eint_int([a0, a1, ...], [c0, c1, ...])
b = HLFHE.dot_eint_int([b0, b1, ...], [d0, d1, ...])
x = HLFHE.dot_eint_int([a, b, ...], [f0, f1, ...])
are merged as follows:
x = HLFHE.dot_eint_int([a0, a1, ..., b0, b1, ...],
[f0*c0, f0*c1, ..., f1*d0, f1*d1, ...])
However, the implementation does not explicitly create the
equivalent dot operations, but only accumulates the squared 2-norm
of the constant vector of the equivalent dot operation along the
edges of the data-flow graph composed by the operations in order to
calculate the final 2-norm for the final single dot operation above.
For the example above, this means that the pass calculates the
squared 2-norm of x, sqN(x) as:
sqN(a) = c0*c0 + c1*c1 + ...
sqN(b) = d0*d0 + d1*d1 + ...
sqN(x) = f0*f0*c0*c0 + f0*f0*c1*c1 + ... + f1*f1*d0*d0 + f1*f1*d1*d1 + ...
= f0*f0*sqN(a) + f1*f1*sqN(b)
This leads to the following rules to calculate the squared 2-norm
for the supported operations:
- Function argument -> 1
- HLFHE.apply_lookup_table -> 1
- HLFHE.zero() -> 1
- HLFHE.dot_eint_int([e0, e1, ...], [c0, c1, ...]) ->
c0*c0*sqN(e0) + c1*c1*sqN(e1) + ...
- HLFHE.add_eint_int(e, c) -> 1*1*sqN(e) + c*c*1*1 = sqN(e) + c*c
- HLFHE.add_eint(e0, e1) -> 1*1*sqN(e0) + 1*1*sqN(e2) = sqN(e1) + sqN(e2)
- HLFHE.sub_int_eint(c, e) -> 1*1*sqN(e) + c*c*(-1)*(-1) = sqN(e) + c*c
- HLFHE.mul_eint_int(e, c) -> c*c*sqN(e)
The final, non-squared 2-norm of an operation is the square root of the
squared value rounded to the next highest integer.
}];
}
#endif

View File

@@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Analysis)
add_subdirectory(IR)

View File

@@ -0,0 +1,15 @@
add_mlir_library(HLFHEDialectAnalysis
MANP.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/HLFHE
DEPENDS
HLFHEDialect
MANPPassIncGen
LINK_LIBS PUBLIC
MLIRIR
HLFHEDialect)
target_link_libraries(HLFHEDialectAnalysis PUBLIC MLIRIR)

View File

@@ -0,0 +1,485 @@
#include <zamalang/Dialect/HLFHE/Analysis/MANP.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHETypes.h>
#include <limits>
#include <llvm/ADT/APInt.h>
#include <llvm/ADT/Optional.h>
#include <llvm/ADT/SmallString.h>
#include <mlir/Analysis/DataFlowAnalysis.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/LogicalResult.h>
#define GEN_PASS_CLASSES
#include <zamalang/Dialect/HLFHE/Analysis/MANP.h.inc>
namespace mlir {
namespace zamalang {
namespace {
// The `MANPLatticeValue` represents the squared Minimal Arithmetic
// Noise Padding for an operation using the squared 2-norm of an
// equivalent dot operation. This can either be an actual value if the
// values for its predecessors have been calculated beforehand or an
// unknown value otherwise.
struct MANPLatticeValue {
MANPLatticeValue(llvm::Optional<llvm::APInt> manp = {}) : manp(manp) {}
static MANPLatticeValue getPessimisticValueState(mlir::MLIRContext *context) {
return MANPLatticeValue();
}
static MANPLatticeValue getPessimisticValueState(mlir::Value value) {
// Function arguments are assumed to require a Minimal Arithmetic
// Noise Padding with a 2-norm of 1.
//
// TODO: Provide a mechanism to propagate Minimal Arithmetic Noise
// Padding across function calls.
if (value.isa<mlir::BlockArgument>() &&
(value.getType().isa<mlir::zamalang::HLFHE::EncryptedIntegerType>() ||
(value.getType().isa<mlir::TensorType>() &&
value.getType()
.cast<mlir::TensorType>()
.getElementType()
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()))) {
return MANPLatticeValue(llvm::APInt{1, 1, false});
} else {
// All other operations have an unknown Minimal Arithmetic Noise
// Padding until an value for all predecessors has been
// calculated.
return MANPLatticeValue();
}
}
bool operator==(const MANPLatticeValue &rhs) const {
return this->manp == rhs.manp;
}
// Required by `mlir::LatticeElement::join()`, but should never be
// invoked, as `MANPAnalysis::visitOperation()` takes care of
// combining the squared Minimal Arithmetic Noise Padding of
// operands into the Minimal Arithmetic Noise Padding of the result.
static MANPLatticeValue join(const MANPLatticeValue &lhs,
const MANPLatticeValue &rhs) {
assert(false && "Minimal Arithmetic Noise Padding values can only be "
"combined sensibly when the combining operation is known");
return MANPLatticeValue{};
}
llvm::Optional<llvm::APInt> getMANP() { return manp; }
protected:
llvm::Optional<llvm::APInt> manp;
};
// Checks if `lhs` is less than `rhs`, where both values are assumed
// to be positive. The bit width of the smaller `APInt` is extended
// before comparison via `APInt::ult`.
static bool APIntWidthExtendULT(const llvm::APInt &lhs,
const llvm::APInt &rhs) {
if (lhs.getBitWidth() < rhs.getBitWidth())
return lhs.zext(rhs.getBitWidth()).ult(rhs);
else if (lhs.getBitWidth() > rhs.getBitWidth())
return lhs.ult(rhs.zext(lhs.getBitWidth()));
else
return lhs.ult(rhs);
}
// Adds two `APInt` values, where both values are assumed to be
// positive. The bit width of the operands is extended in order to
// guarantee that the sum fits into the resulting `APInt`.
static llvm::APInt APIntWidthExtendUAdd(const llvm::APInt &lhs,
const llvm::APInt &rhs) {
unsigned maxBits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
// Make sure the required number of bits can be represented by the
// `unsigned` argument of `zext`.
assert(std::numeric_limits<unsigned>::max() - maxBits > 1);
unsigned targetWidth = maxBits + 1;
return lhs.zext(targetWidth) + rhs.zext(targetWidth);
}
// Multiplies two `APInt` values, where both values are assumed to be
// positive. The bit width of the operands is extended in order to
// guarantee that the product fits into the resulting `APInt`.
static llvm::APInt APIntWidthExtendUMul(const llvm::APInt &lhs,
const llvm::APInt &rhs) {
// Make sure the required number of bits can be represented by the
// `unsigned` argument of `zext`.
assert(std::numeric_limits<unsigned>::max() -
std::max(lhs.getBitWidth(), rhs.getBitWidth()) >
std::min(lhs.getBitWidth(), rhs.getBitWidth()) &&
"Required number of bits cannot be represented with an APInt");
unsigned targetWidth = lhs.getBitWidth() + rhs.getBitWidth();
return lhs.zext(targetWidth) * rhs.zext(targetWidth);
}
// Calculates the square of `i`. The bit width `i` is extended in
// order to guarantee that the product fits into the resulting
// `APInt`.
static llvm::APInt APIntWidthExtendUSq(const llvm::APInt &i) {
// Make sure the required number of bits can be represented by the
// `unsigned` argument of `zext`.
assert(i.getBitWidth() < std::numeric_limits<unsigned>::max() / 2 &&
"Required number of bits cannot be represented with an APInt");
llvm::APInt ie = i.zext(2 * i.getBitWidth());
return ie * ie;
}
// Calculates the square root of `i` and rounds it to the next highest
// integer value (i.e., the square of the result is guaranteed to be
// greater or equal to `i`).
static llvm::APInt APIntCeilSqrt(const llvm::APInt &i) {
llvm::APInt res = i.sqrt();
llvm::APInt resSq = APIntWidthExtendUSq(res);
if (APIntWidthExtendULT(resSq, i))
return APIntWidthExtendUAdd(res, llvm::APInt{1, 1, false});
else
return res;
}
// Returns a string representation of `i` assuming that `i` is an
// unsigned value.
static std::string APIntToStringValUnsigned(const llvm::APInt &i) {
llvm::SmallString<32> s;
i.toStringUnsigned(s);
return std::string(s.c_str());
}
// Calculates the square of the 2-norm of a tensor initialized with a
// dense matrix of constant, signless integers. Aborts if the value
// type or initialization of of `cstOp` is incorrect.
static llvm::APInt denseCstTensorNorm2Sq(mlir::ConstantOp cstOp) {
mlir::DenseIntElementsAttr denseVals =
cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value");
assert(denseVals && cstOp.getType().isa<mlir::TensorType>() &&
"Constant must be a tensor initialized with `dense`");
mlir::TensorType tensorType = cstOp.getType().cast<mlir::TensorType>();
assert(tensorType.getElementType().isSignlessInteger() &&
"Can only handle tensors with signless integer elements");
mlir::IntegerType elementType =
tensorType.getElementType().cast<mlir::IntegerType>();
llvm::APInt accu{1, 0, false};
for (llvm::APInt val : denseVals.getIntValues()) {
llvm::APInt valSq = APIntWidthExtendUSq(val);
accu = APIntWidthExtendUAdd(accu, valSq);
}
return accu;
}
// Calculates (T)ceil(log2f(v))
// TODO: Replace with some fancy bit twiddling hack
template <typename T> static T ceilLog2(const T v) {
T tmp = v;
T log2 = 0;
while (tmp >>= 1)
log2++;
// If more than MSB set, round to next highest power of 2
if (v & ~((T)1 << log2))
log2 += 1;
return log2;
}
// Calculates the square of the 2-norm of a 1D tensor of signless
// integers by conservatively assuming that the dynamic values are the
// maximum for the integer width. Aborts if the tensor type `tTy` is
// incorrect.
static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy) {
assert(tTy && tTy.getElementType().isSignlessInteger() &&
tTy.hasStaticShape() && tTy.getRank() == 1 &&
"Plaintext operand must be a statically shaped 1D tensor of integers");
// Make sure the log2 of the number of elements fits into an
// unsigned
assert(std::numeric_limits<unsigned>::max() > 8 * sizeof(uint64_t));
unsigned elWidth = tTy.getElementTypeBitWidth();
llvm::APInt maxVal = APInt::getMaxValue(elWidth);
llvm::APInt maxValSq = APIntWidthExtendUSq(maxVal);
// Calculate number of bits for APInt to store number of elements
uint64_t nElts = (uint64_t)tTy.getNumElements();
assert(std::numeric_limits<int64_t>::max() - nElts > 1);
unsigned nEltsBits = (unsigned)ceilLog2(nElts + 1);
llvm::APInt nEltsAP{nEltsBits, nElts, false};
return APIntWidthExtendUMul(maxValSq, nEltsAP);
}
// Calculates the squared Minimal Arithmetic Noise Padding of an
// `HLFHE.dot_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::zamalang::HLFHE::Dot op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
assert(op->getOpOperand(0).get().isa<mlir::BlockArgument>() &&
"Only dot operations with tensors that are function arguments are "
"currently supported");
mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
op->getOpOperand(1).get().getDefiningOp());
if (cstOp) {
// Dot product between a vector of encrypted integers and a vector
// of plaintext constants -> return 2-norm of constant vector
return denseCstTensorNorm2Sq(cstOp);
} else {
// Dot product between a vector of encrypted integers and a vector
// of dynamic plaintext values -> conservatively assume that all
// the values are the maximum possible value for the integer's
// width
mlir::TensorType tTy = op->getOpOperand(1)
.get()
.getType()
.dyn_cast_or_null<mlir::TensorType>();
return denseDynTensorNorm2Sq(tTy);
}
}
// Returns the squared 2-norm for a dynamic integer by conservatively
// assuming that the integer's value is the maximum for the integer
// width.
static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) {
assert(t.isSignlessInteger() && "Type must be a signless integer type");
assert(std::numeric_limits<unsigned>::max() - t.getIntOrFloatBitWidth() > 1);
llvm::APInt maxVal{t.getIntOrFloatBitWidth() + 1, 1, false};
maxVal <<= t.getIntOrFloatBitWidth();
return APIntWidthExtendUSq(maxVal);
}
// Calculates the squared Minimal Arithmetic Noise Padding of an
// `HLFHE.add_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::zamalang::HLFHE::AddEintIntOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
mlir::Type iTy = op->getOpOperand(1).get().getType();
assert(iTy.isSignlessInteger() &&
"Only additions with signless integers are currently allowed");
assert(
operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
op->getOpOperand(1).get().getDefiningOp());
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
llvm::APInt sqNorm;
if (cstOp) {
// For a constant operand use actual constant to calculate 2-norm
mlir::IntegerAttr attr = cstOp->getAttrOfType<mlir::IntegerAttr>("value");
sqNorm = APIntWidthExtendUSq(attr.getValue());
} else {
// For a dynamic operand conservatively assume that the value is
// the maximum for the integer width
sqNorm = conservativeIntNorm2Sq(iTy);
}
return APIntWidthExtendUAdd(sqNorm, eNorm);
}
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `HLFHE.add_eint` operation.
static llvm::APInt getSqMANP(
mlir::zamalang::HLFHE::AddEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
assert(operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
operandMANPs[1]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
"operands");
llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue();
llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue();
return APIntWidthExtendUAdd(a, b);
}
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `HLFHE.sub_int_eint` operation.
static llvm::APInt getSqMANP(
mlir::zamalang::HLFHE::SubIntEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
mlir::Type iTy = op->getOpOperand(0).get().getType();
assert(iTy.isSignlessInteger() &&
"Only subtractions with signless integers are currently allowed");
assert(
operandMANPs.size() == 2 &&
operandMANPs[1]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().getValue();
llvm::APInt sqNorm;
mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
op->getOpOperand(0).get().getDefiningOp());
if (cstOp) {
// For constant plaintext operands simply use the constant value
mlir::IntegerAttr attr = cstOp->getAttrOfType<mlir::IntegerAttr>("value");
sqNorm = APIntWidthExtendUSq(attr.getValue());
} else {
// For dynamic plaintext operands conservatively assume that the integer has
// its maximum possible value
sqNorm = conservativeIntNorm2Sq(iTy);
}
return APIntWidthExtendUAdd(sqNorm, eNorm);
}
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `HLFHE.mul_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::zamalang::HLFHE::MulEintIntOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
mlir::Type iTy = op->getOpOperand(1).get().getType();
assert(iTy.isSignlessInteger() &&
"Only multiplications with signless integers are currently allowed");
assert(
operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
op->getOpOperand(1).get().getDefiningOp());
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
llvm::APInt sqNorm;
if (cstOp) {
// For a constant operand use actual constant to calculate 2-norm
mlir::IntegerAttr attr = cstOp->getAttrOfType<mlir::IntegerAttr>("value");
sqNorm = APIntWidthExtendUSq(attr.getValue());
} else {
// For a dynamic operand conservatively assume that the value is
// the maximum for the integer width
sqNorm = conservativeIntNorm2Sq(iTy);
}
return APIntWidthExtendUMul(sqNorm, eNorm);
}
struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
: debug(debug), mlir::ForwardDataFlowAnalysis<MANPLatticeValue>(ctx) {}
~MANPAnalysis() override = default;
mlir::ChangeResult visitOperation(
mlir::Operation *op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operands) final {
mlir::LatticeElement<MANPLatticeValue> &latticeRes =
getLatticeElement(op->getResult(0));
bool isDummy = false;
llvm::APInt norm2SqEquiv;
if (auto dotOp = llvm::dyn_cast<mlir::zamalang::HLFHE::Dot>(op)) {
norm2SqEquiv = getSqMANP(dotOp, operands);
} else if (auto addEintIntOp =
llvm::dyn_cast<mlir::zamalang::HLFHE::AddEintIntOp>(op)) {
norm2SqEquiv = getSqMANP(addEintIntOp, operands);
} else if (auto addEintOp =
llvm::dyn_cast<mlir::zamalang::HLFHE::AddEintOp>(op)) {
norm2SqEquiv = getSqMANP(addEintOp, operands);
} else if (auto subIntEintOp =
llvm::dyn_cast<mlir::zamalang::HLFHE::SubIntEintOp>(op)) {
norm2SqEquiv = getSqMANP(subIntEintOp, operands);
} else if (auto mulEintIntOp =
llvm::dyn_cast<mlir::zamalang::HLFHE::MulEintIntOp>(op)) {
norm2SqEquiv = getSqMANP(mulEintIntOp, operands);
} else if (llvm::isa<mlir::zamalang::HLFHE::ZeroEintOp>(op) ||
llvm::isa<mlir::zamalang::HLFHE::ApplyLookupTableEintOp>(op)) {
norm2SqEquiv = llvm::APInt{1, 1, false};
} else if (llvm::isa<mlir::ConstantOp>(op)) {
isDummy = true;
} else if (llvm::isa<mlir::zamalang::HLFHE::HLFHEDialect>(
*op->getDialect())) {
op->emitError("Unsupported operation");
assert(false && "Unsupported operation");
} else {
isDummy = true;
}
if (!isDummy) {
latticeRes.join(MANPLatticeValue{norm2SqEquiv});
latticeRes.markOptimisticFixpoint();
llvm::APInt norm2Equiv = APIntCeilSqrt(norm2SqEquiv);
op->setAttr("MANP",
mlir::IntegerAttr::get(
mlir::IntegerType::get(
op->getContext(), norm2Equiv.getBitWidth(),
mlir::IntegerType::SignednessSemantics::Unsigned),
norm2Equiv));
if (debug) {
op->emitRemark("Squared Minimal Arithmetic Noise Padding: ")
<< APIntToStringValUnsigned(norm2SqEquiv) << "\n";
}
} else {
latticeRes.join(MANPLatticeValue{});
}
return mlir::ChangeResult::Change;
}
private:
bool debug;
};
} // namespace
namespace {
// For documentation see MANP.td
struct MANPPass : public MANPBase<MANPPass> {
void runOnFunction() override {
mlir::FuncOp func = getFunction();
MANPAnalysis analysis(func->getContext(), debug);
analysis.run(func);
}
MANPPass() = delete;
MANPPass(bool debug) : debug(debug){};
protected:
bool debug;
};
} // end anonymous namespace
// Create an instance of the Minimal Arithmetic Noise Padding analysis
// pass. If `debug` is true, for each operation, the pass emits a
// remark containing the squared Minimal Arithmetic Noise Padding of
// the equivalent dot operation.
std::unique_ptr<mlir::Pass> createMANPPass(bool debug) {
return std::make_unique<MANPPass>(debug);
}
} // namespace zamalang
} // namespace mlir

View File

@@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Analysis)
add_subdirectory(IR)

View File

@@ -14,6 +14,7 @@ target_link_libraries(zamacompiler
LowLFHEDialect
MidLFHEDialect
HLFHEDialect
HLFHEDialectAnalysis
MLIRIR
MLIRLLVMIR

View File

@@ -15,6 +15,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Conversion/Utils/GlobalFHEContext.h"
#include "zamalang/Dialect/HLFHE/Analysis/MANP.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
@@ -30,6 +31,7 @@ enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM };
enum Action {
ROUND_TRIP,
DEBUG_MANP,
DUMP_MIDLFHE,
DUMP_LOWLFHE,
DUMP_STD,
@@ -84,6 +86,9 @@ static llvm::cl::opt<enum Action> action(
llvm::cl::values(
clEnumValN(Action::ROUND_TRIP, "roundtrip",
"Parse input module and regenerate textual representation")),
llvm::cl::values(clEnumValN(
Action::DEBUG_MANP, "debug-manp",
"Minimal Arithmetic Noise Padding for each HLFHE operation")),
llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe",
"Lower to MidLFHE and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe",
@@ -253,10 +258,28 @@ mlir::LogicalResult processInputBuffer(
// a fallthrough mechanism to the next stage. Actions act as exit
// points from the pipeline.
switch (entryDialect) {
case EntryDialect::HLFHE:
case EntryDialect::HLFHE: {
bool debugMANP = (action == Action::DEBUG_MANP);
mlir::LogicalResult manpRes =
mlir::zamalang::invokeMANPPass(module, debugMANP);
if (action == Action::DEBUG_MANP) {
if (manpRes.failed()) {
mlir::zamalang::log_error()
<< "Could not calculate Minimal Arithmetic Noise Padding";
if (!verifyDiagnostics)
return mlir::failure();
} else {
return mlir::success();
}
}
if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose)
.failed())
return mlir::failure();
}
// fallthrough
case EntryDialect::MIDLFHE: