mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): Add pass for Minimal Arithmetic Noise Padding
This pass calculates the squared Minimal Arithmetic Noise Padding (MANP) for each operation of a function and stores the result in an integer attribute named "sqMANP". This metric is identical to the squared 2-norm of the constant vector of an equivalent dot product between a vector of encrypted integers resulting directly from an encryption and a vector of plaintext constants. The pass supports the following operations: - HLFHE.dot_eint_int - HLFHE.zero - HLFHE.add_eint_int - HLFHE.add_eint - HLFHE.sub_int_eint - HLFHE.mul_eint_int - HLFHE.apply_lookup_table If any other operation is encountered, the pass conservatively fails. The pass further makes the optimistic assumption that all values passed to a function are either the direct result of an encryption of a noise-refreshing operation.
This commit is contained in:
committed by
Quentin Bourgerie
parent
30374ebb2c
commit
ed762942c1
@@ -0,0 +1,7 @@
|
||||
set(LLVM_TARGET_DEFINITIONS MANP.td)
|
||||
mlir_tablegen(MANP.h.inc -gen-pass-decls -name Analysis)
|
||||
mlir_tablegen(MANP.capi.h.inc -gen-pass-capi-header --prefix Analysis)
|
||||
mlir_tablegen(MANP.capi.cpp.inc -gen-pass-capi-impl --prefix Analysis)
|
||||
add_public_tablegen_target(MANPPassIncGen)
|
||||
|
||||
add_mlir_doc(MANP GeneralPasses ./ -gen-pass-doc)
|
||||
12
compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.h
Normal file
12
compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.h
Normal file
@@ -0,0 +1,12 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP_H
|
||||
#define ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP_H
|
||||
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
std::unique_ptr<mlir::Pass> createMANPPass(bool debug = false);
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
95
compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.td
Normal file
95
compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.td
Normal file
@@ -0,0 +1,95 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP
|
||||
#define ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def MANP : FunctionPass<"MANP"> {
|
||||
let summary = "HLFHE Minimal Arithmetic Noise Padding Pass";
|
||||
let description = [{
|
||||
This pass calculates the Minimal Arithmetic Noise Padding
|
||||
(MANP) for each operation of a function and stores the result in an
|
||||
integer attribute named "MANP". This metric is identical to the
|
||||
ceiled 2-norm of the constant vector of an equivalent dot product
|
||||
between a vector of encrypted integers resulting directly from an
|
||||
encryption and a vector of plaintext constants.
|
||||
|
||||
The pass supports the following operations:
|
||||
|
||||
- HLFHE.dot_eint_int
|
||||
- HLFHE.zero
|
||||
- HLFHE.add_eint_int
|
||||
- HLFHE.add_eint
|
||||
- HLFHE.sub_int_eint
|
||||
- HLFHE.mul_eint_int
|
||||
- HLFHE.apply_lookup_table
|
||||
|
||||
If any other operation is encountered, the pass conservatively
|
||||
fails. The pass further makes the optimistic assumption that all
|
||||
values passed to a function are either the direct result of an
|
||||
encryption of a noise-refreshing operation.
|
||||
|
||||
Conceptually, the pass is equivalent to the three steps below:
|
||||
|
||||
1. Replace all arithmetic operations with an equivalent dot
|
||||
operation
|
||||
|
||||
2. Merge resulting dot operations into a single, equivalent
|
||||
dot operation
|
||||
|
||||
3. Calculate the 2-norm of the vector of plaintext constants
|
||||
of the dot operation
|
||||
|
||||
with the following replacement rules:
|
||||
|
||||
- Function argument a -> HLFHE.dot_eint_int([a], [1])
|
||||
- HLFHE.apply_lookup_table -> HLFHE.dot_eint_int([LUT result], [1])
|
||||
- HLFHE.zero() -> HLFHE.dot_eint_int([encrypted 0], [1])
|
||||
- HLFHE.add_eint_int(e, c) -> HLFHE.dot_eint_int([e, 1], [1, c])
|
||||
- HLFHE.add_eint(e0, e1) -> HLFHE.dot_eint_int([e0, e1], [1, 1])
|
||||
- HLFHE.sub_int_eint(c, e) -> HLFHE.dot_eint_int([e, c], [1, -1])
|
||||
- HLFHE.mul_eint_int(e, c) -> HLFHE.dot_eint_int([e], [c])
|
||||
|
||||
Dependent dot operations, e.g.,
|
||||
|
||||
a = HLFHE.dot_eint_int([a0, a1, ...], [c0, c1, ...])
|
||||
b = HLFHE.dot_eint_int([b0, b1, ...], [d0, d1, ...])
|
||||
x = HLFHE.dot_eint_int([a, b, ...], [f0, f1, ...])
|
||||
|
||||
are merged as follows:
|
||||
|
||||
x = HLFHE.dot_eint_int([a0, a1, ..., b0, b1, ...],
|
||||
[f0*c0, f0*c1, ..., f1*d0, f1*d1, ...])
|
||||
|
||||
However, the implementation does not explicitly create the
|
||||
equivalent dot operations, but only accumulates the squared 2-norm
|
||||
of the constant vector of the equivalent dot operation along the
|
||||
edges of the data-flow graph composed by the operations in order to
|
||||
calculate the final 2-norm for the final single dot operation above.
|
||||
|
||||
For the example above, this means that the pass calculates the
|
||||
squared 2-norm of x, sqN(x) as:
|
||||
|
||||
sqN(a) = c0*c0 + c1*c1 + ...
|
||||
sqN(b) = d0*d0 + d1*d1 + ...
|
||||
sqN(x) = f0*f0*c0*c0 + f0*f0*c1*c1 + ... + f1*f1*d0*d0 + f1*f1*d1*d1 + ...
|
||||
= f0*f0*sqN(a) + f1*f1*sqN(b)
|
||||
|
||||
This leads to the following rules to calculate the squared 2-norm
|
||||
for the supported operations:
|
||||
|
||||
- Function argument -> 1
|
||||
- HLFHE.apply_lookup_table -> 1
|
||||
- HLFHE.zero() -> 1
|
||||
- HLFHE.dot_eint_int([e0, e1, ...], [c0, c1, ...]) ->
|
||||
c0*c0*sqN(e0) + c1*c1*sqN(e1) + ...
|
||||
- HLFHE.add_eint_int(e, c) -> 1*1*sqN(e) + c*c*1*1 = sqN(e) + c*c
|
||||
- HLFHE.add_eint(e0, e1) -> 1*1*sqN(e0) + 1*1*sqN(e2) = sqN(e1) + sqN(e2)
|
||||
- HLFHE.sub_int_eint(c, e) -> 1*1*sqN(e) + c*c*(-1)*(-1) = sqN(e) + c*c
|
||||
- HLFHE.mul_eint_int(e, c) -> c*c*sqN(e)
|
||||
|
||||
The final, non-squared 2-norm of an operation is the square root of the
|
||||
squared value rounded to the next highest integer.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Analysis)
|
||||
add_subdirectory(IR)
|
||||
|
||||
15
compiler/lib/Dialect/HLFHE/Analysis/CMakeLists.txt
Normal file
15
compiler/lib/Dialect/HLFHE/Analysis/CMakeLists.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
add_mlir_library(HLFHEDialectAnalysis
|
||||
MANP.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/HLFHE
|
||||
|
||||
DEPENDS
|
||||
HLFHEDialect
|
||||
MANPPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
HLFHEDialect)
|
||||
|
||||
target_link_libraries(HLFHEDialectAnalysis PUBLIC MLIRIR)
|
||||
485
compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp
Normal file
485
compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp
Normal file
@@ -0,0 +1,485 @@
|
||||
#include <zamalang/Dialect/HLFHE/Analysis/MANP.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHETypes.h>
|
||||
|
||||
#include <limits>
|
||||
#include <llvm/ADT/APInt.h>
|
||||
#include <llvm/ADT/Optional.h>
|
||||
#include <llvm/ADT/SmallString.h>
|
||||
#include <mlir/Analysis/DataFlowAnalysis.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <mlir/IR/Attributes.h>
|
||||
#include <mlir/IR/BuiltinAttributes.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <zamalang/Dialect/HLFHE/Analysis/MANP.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
namespace {
|
||||
// The `MANPLatticeValue` represents the squared Minimal Arithmetic
|
||||
// Noise Padding for an operation using the squared 2-norm of an
|
||||
// equivalent dot operation. This can either be an actual value if the
|
||||
// values for its predecessors have been calculated beforehand or an
|
||||
// unknown value otherwise.
|
||||
struct MANPLatticeValue {
|
||||
MANPLatticeValue(llvm::Optional<llvm::APInt> manp = {}) : manp(manp) {}
|
||||
|
||||
static MANPLatticeValue getPessimisticValueState(mlir::MLIRContext *context) {
|
||||
return MANPLatticeValue();
|
||||
}
|
||||
|
||||
static MANPLatticeValue getPessimisticValueState(mlir::Value value) {
|
||||
// Function arguments are assumed to require a Minimal Arithmetic
|
||||
// Noise Padding with a 2-norm of 1.
|
||||
//
|
||||
// TODO: Provide a mechanism to propagate Minimal Arithmetic Noise
|
||||
// Padding across function calls.
|
||||
if (value.isa<mlir::BlockArgument>() &&
|
||||
(value.getType().isa<mlir::zamalang::HLFHE::EncryptedIntegerType>() ||
|
||||
(value.getType().isa<mlir::TensorType>() &&
|
||||
value.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()))) {
|
||||
return MANPLatticeValue(llvm::APInt{1, 1, false});
|
||||
} else {
|
||||
// All other operations have an unknown Minimal Arithmetic Noise
|
||||
// Padding until an value for all predecessors has been
|
||||
// calculated.
|
||||
return MANPLatticeValue();
|
||||
}
|
||||
}
|
||||
|
||||
bool operator==(const MANPLatticeValue &rhs) const {
|
||||
return this->manp == rhs.manp;
|
||||
}
|
||||
|
||||
// Required by `mlir::LatticeElement::join()`, but should never be
|
||||
// invoked, as `MANPAnalysis::visitOperation()` takes care of
|
||||
// combining the squared Minimal Arithmetic Noise Padding of
|
||||
// operands into the Minimal Arithmetic Noise Padding of the result.
|
||||
static MANPLatticeValue join(const MANPLatticeValue &lhs,
|
||||
const MANPLatticeValue &rhs) {
|
||||
assert(false && "Minimal Arithmetic Noise Padding values can only be "
|
||||
"combined sensibly when the combining operation is known");
|
||||
return MANPLatticeValue{};
|
||||
}
|
||||
|
||||
llvm::Optional<llvm::APInt> getMANP() { return manp; }
|
||||
|
||||
protected:
|
||||
llvm::Optional<llvm::APInt> manp;
|
||||
};
|
||||
|
||||
// Checks if `lhs` is less than `rhs`, where both values are assumed
|
||||
// to be positive. The bit width of the smaller `APInt` is extended
|
||||
// before comparison via `APInt::ult`.
|
||||
static bool APIntWidthExtendULT(const llvm::APInt &lhs,
|
||||
const llvm::APInt &rhs) {
|
||||
if (lhs.getBitWidth() < rhs.getBitWidth())
|
||||
return lhs.zext(rhs.getBitWidth()).ult(rhs);
|
||||
else if (lhs.getBitWidth() > rhs.getBitWidth())
|
||||
return lhs.ult(rhs.zext(lhs.getBitWidth()));
|
||||
else
|
||||
return lhs.ult(rhs);
|
||||
}
|
||||
|
||||
// Adds two `APInt` values, where both values are assumed to be
|
||||
// positive. The bit width of the operands is extended in order to
|
||||
// guarantee that the sum fits into the resulting `APInt`.
|
||||
static llvm::APInt APIntWidthExtendUAdd(const llvm::APInt &lhs,
|
||||
const llvm::APInt &rhs) {
|
||||
unsigned maxBits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
|
||||
|
||||
// Make sure the required number of bits can be represented by the
|
||||
// `unsigned` argument of `zext`.
|
||||
assert(std::numeric_limits<unsigned>::max() - maxBits > 1);
|
||||
|
||||
unsigned targetWidth = maxBits + 1;
|
||||
return lhs.zext(targetWidth) + rhs.zext(targetWidth);
|
||||
}
|
||||
|
||||
// Multiplies two `APInt` values, where both values are assumed to be
|
||||
// positive. The bit width of the operands is extended in order to
|
||||
// guarantee that the product fits into the resulting `APInt`.
|
||||
static llvm::APInt APIntWidthExtendUMul(const llvm::APInt &lhs,
|
||||
const llvm::APInt &rhs) {
|
||||
// Make sure the required number of bits can be represented by the
|
||||
// `unsigned` argument of `zext`.
|
||||
assert(std::numeric_limits<unsigned>::max() -
|
||||
std::max(lhs.getBitWidth(), rhs.getBitWidth()) >
|
||||
std::min(lhs.getBitWidth(), rhs.getBitWidth()) &&
|
||||
"Required number of bits cannot be represented with an APInt");
|
||||
|
||||
unsigned targetWidth = lhs.getBitWidth() + rhs.getBitWidth();
|
||||
return lhs.zext(targetWidth) * rhs.zext(targetWidth);
|
||||
}
|
||||
|
||||
// Calculates the square of `i`. The bit width `i` is extended in
|
||||
// order to guarantee that the product fits into the resulting
|
||||
// `APInt`.
|
||||
static llvm::APInt APIntWidthExtendUSq(const llvm::APInt &i) {
|
||||
// Make sure the required number of bits can be represented by the
|
||||
// `unsigned` argument of `zext`.
|
||||
assert(i.getBitWidth() < std::numeric_limits<unsigned>::max() / 2 &&
|
||||
"Required number of bits cannot be represented with an APInt");
|
||||
|
||||
llvm::APInt ie = i.zext(2 * i.getBitWidth());
|
||||
|
||||
return ie * ie;
|
||||
}
|
||||
|
||||
// Calculates the square root of `i` and rounds it to the next highest
|
||||
// integer value (i.e., the square of the result is guaranteed to be
|
||||
// greater or equal to `i`).
|
||||
static llvm::APInt APIntCeilSqrt(const llvm::APInt &i) {
|
||||
llvm::APInt res = i.sqrt();
|
||||
llvm::APInt resSq = APIntWidthExtendUSq(res);
|
||||
|
||||
if (APIntWidthExtendULT(resSq, i))
|
||||
return APIntWidthExtendUAdd(res, llvm::APInt{1, 1, false});
|
||||
else
|
||||
return res;
|
||||
}
|
||||
|
||||
// Returns a string representation of `i` assuming that `i` is an
|
||||
// unsigned value.
|
||||
static std::string APIntToStringValUnsigned(const llvm::APInt &i) {
|
||||
llvm::SmallString<32> s;
|
||||
i.toStringUnsigned(s);
|
||||
return std::string(s.c_str());
|
||||
}
|
||||
|
||||
// Calculates the square of the 2-norm of a tensor initialized with a
|
||||
// dense matrix of constant, signless integers. Aborts if the value
|
||||
// type or initialization of of `cstOp` is incorrect.
|
||||
static llvm::APInt denseCstTensorNorm2Sq(mlir::ConstantOp cstOp) {
|
||||
mlir::DenseIntElementsAttr denseVals =
|
||||
cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value");
|
||||
|
||||
assert(denseVals && cstOp.getType().isa<mlir::TensorType>() &&
|
||||
"Constant must be a tensor initialized with `dense`");
|
||||
|
||||
mlir::TensorType tensorType = cstOp.getType().cast<mlir::TensorType>();
|
||||
|
||||
assert(tensorType.getElementType().isSignlessInteger() &&
|
||||
"Can only handle tensors with signless integer elements");
|
||||
|
||||
mlir::IntegerType elementType =
|
||||
tensorType.getElementType().cast<mlir::IntegerType>();
|
||||
|
||||
llvm::APInt accu{1, 0, false};
|
||||
|
||||
for (llvm::APInt val : denseVals.getIntValues()) {
|
||||
llvm::APInt valSq = APIntWidthExtendUSq(val);
|
||||
accu = APIntWidthExtendUAdd(accu, valSq);
|
||||
}
|
||||
|
||||
return accu;
|
||||
}
|
||||
|
||||
// Calculates (T)ceil(log2f(v))
|
||||
// TODO: Replace with some fancy bit twiddling hack
|
||||
template <typename T> static T ceilLog2(const T v) {
|
||||
T tmp = v;
|
||||
T log2 = 0;
|
||||
|
||||
while (tmp >>= 1)
|
||||
log2++;
|
||||
|
||||
// If more than MSB set, round to next highest power of 2
|
||||
if (v & ~((T)1 << log2))
|
||||
log2 += 1;
|
||||
|
||||
return log2;
|
||||
}
|
||||
|
||||
// Calculates the square of the 2-norm of a 1D tensor of signless
|
||||
// integers by conservatively assuming that the dynamic values are the
|
||||
// maximum for the integer width. Aborts if the tensor type `tTy` is
|
||||
// incorrect.
|
||||
static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy) {
|
||||
assert(tTy && tTy.getElementType().isSignlessInteger() &&
|
||||
tTy.hasStaticShape() && tTy.getRank() == 1 &&
|
||||
"Plaintext operand must be a statically shaped 1D tensor of integers");
|
||||
|
||||
// Make sure the log2 of the number of elements fits into an
|
||||
// unsigned
|
||||
assert(std::numeric_limits<unsigned>::max() > 8 * sizeof(uint64_t));
|
||||
|
||||
unsigned elWidth = tTy.getElementTypeBitWidth();
|
||||
|
||||
llvm::APInt maxVal = APInt::getMaxValue(elWidth);
|
||||
llvm::APInt maxValSq = APIntWidthExtendUSq(maxVal);
|
||||
|
||||
// Calculate number of bits for APInt to store number of elements
|
||||
uint64_t nElts = (uint64_t)tTy.getNumElements();
|
||||
assert(std::numeric_limits<int64_t>::max() - nElts > 1);
|
||||
unsigned nEltsBits = (unsigned)ceilLog2(nElts + 1);
|
||||
|
||||
llvm::APInt nEltsAP{nEltsBits, nElts, false};
|
||||
|
||||
return APIntWidthExtendUMul(maxValSq, nEltsAP);
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of an
|
||||
// `HLFHE.dot_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::zamalang::HLFHE::Dot op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
assert(op->getOpOperand(0).get().isa<mlir::BlockArgument>() &&
|
||||
"Only dot operations with tensors that are function arguments are "
|
||||
"currently supported");
|
||||
|
||||
mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
|
||||
op->getOpOperand(1).get().getDefiningOp());
|
||||
|
||||
if (cstOp) {
|
||||
// Dot product between a vector of encrypted integers and a vector
|
||||
// of plaintext constants -> return 2-norm of constant vector
|
||||
return denseCstTensorNorm2Sq(cstOp);
|
||||
} else {
|
||||
// Dot product between a vector of encrypted integers and a vector
|
||||
// of dynamic plaintext values -> conservatively assume that all
|
||||
// the values are the maximum possible value for the integer's
|
||||
// width
|
||||
mlir::TensorType tTy = op->getOpOperand(1)
|
||||
.get()
|
||||
.getType()
|
||||
.dyn_cast_or_null<mlir::TensorType>();
|
||||
|
||||
return denseDynTensorNorm2Sq(tTy);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the squared 2-norm for a dynamic integer by conservatively
|
||||
// assuming that the integer's value is the maximum for the integer
|
||||
// width.
|
||||
static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) {
|
||||
assert(t.isSignlessInteger() && "Type must be a signless integer type");
|
||||
assert(std::numeric_limits<unsigned>::max() - t.getIntOrFloatBitWidth() > 1);
|
||||
|
||||
llvm::APInt maxVal{t.getIntOrFloatBitWidth() + 1, 1, false};
|
||||
maxVal <<= t.getIntOrFloatBitWidth();
|
||||
return APIntWidthExtendUSq(maxVal);
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of an
|
||||
// `HLFHE.add_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::zamalang::HLFHE::AddEintIntOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
mlir::Type iTy = op->getOpOperand(1).get().getType();
|
||||
|
||||
assert(iTy.isSignlessInteger() &&
|
||||
"Only additions with signless integers are currently allowed");
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
|
||||
op->getOpOperand(1).get().getDefiningOp());
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt sqNorm;
|
||||
|
||||
if (cstOp) {
|
||||
// For a constant operand use actual constant to calculate 2-norm
|
||||
mlir::IntegerAttr attr = cstOp->getAttrOfType<mlir::IntegerAttr>("value");
|
||||
sqNorm = APIntWidthExtendUSq(attr.getValue());
|
||||
} else {
|
||||
// For a dynamic operand conservatively assume that the value is
|
||||
// the maximum for the integer width
|
||||
sqNorm = conservativeIntNorm2Sq(iTy);
|
||||
}
|
||||
|
||||
return APIntWidthExtendUAdd(sqNorm, eNorm);
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
// that is equivalent to an `HLFHE.add_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::zamalang::HLFHE::AddEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
assert(operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operands");
|
||||
|
||||
llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue();
|
||||
|
||||
return APIntWidthExtendUAdd(a, b);
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
// that is equivalent to an `HLFHE.sub_int_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::zamalang::HLFHE::SubIntEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
mlir::Type iTy = op->getOpOperand(0).get().getType();
|
||||
|
||||
assert(iTy.isSignlessInteger() &&
|
||||
"Only subtractions with signless integers are currently allowed");
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().getValue();
|
||||
llvm::APInt sqNorm;
|
||||
|
||||
mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
|
||||
op->getOpOperand(0).get().getDefiningOp());
|
||||
|
||||
if (cstOp) {
|
||||
// For constant plaintext operands simply use the constant value
|
||||
mlir::IntegerAttr attr = cstOp->getAttrOfType<mlir::IntegerAttr>("value");
|
||||
sqNorm = APIntWidthExtendUSq(attr.getValue());
|
||||
} else {
|
||||
// For dynamic plaintext operands conservatively assume that the integer has
|
||||
// its maximum possible value
|
||||
sqNorm = conservativeIntNorm2Sq(iTy);
|
||||
}
|
||||
return APIntWidthExtendUAdd(sqNorm, eNorm);
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
// that is equivalent to an `HLFHE.mul_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::zamalang::HLFHE::MulEintIntOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
mlir::Type iTy = op->getOpOperand(1).get().getType();
|
||||
|
||||
assert(iTy.isSignlessInteger() &&
|
||||
"Only multiplications with signless integers are currently allowed");
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
|
||||
op->getOpOperand(1).get().getDefiningOp());
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt sqNorm;
|
||||
|
||||
if (cstOp) {
|
||||
// For a constant operand use actual constant to calculate 2-norm
|
||||
mlir::IntegerAttr attr = cstOp->getAttrOfType<mlir::IntegerAttr>("value");
|
||||
sqNorm = APIntWidthExtendUSq(attr.getValue());
|
||||
} else {
|
||||
// For a dynamic operand conservatively assume that the value is
|
||||
// the maximum for the integer width
|
||||
sqNorm = conservativeIntNorm2Sq(iTy);
|
||||
}
|
||||
|
||||
return APIntWidthExtendUMul(sqNorm, eNorm);
|
||||
}
|
||||
|
||||
struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
|
||||
MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
|
||||
: debug(debug), mlir::ForwardDataFlowAnalysis<MANPLatticeValue>(ctx) {}
|
||||
|
||||
~MANPAnalysis() override = default;
|
||||
|
||||
mlir::ChangeResult visitOperation(
|
||||
mlir::Operation *op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operands) final {
|
||||
mlir::LatticeElement<MANPLatticeValue> &latticeRes =
|
||||
getLatticeElement(op->getResult(0));
|
||||
bool isDummy = false;
|
||||
llvm::APInt norm2SqEquiv;
|
||||
|
||||
if (auto dotOp = llvm::dyn_cast<mlir::zamalang::HLFHE::Dot>(op)) {
|
||||
norm2SqEquiv = getSqMANP(dotOp, operands);
|
||||
} else if (auto addEintIntOp =
|
||||
llvm::dyn_cast<mlir::zamalang::HLFHE::AddEintIntOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(addEintIntOp, operands);
|
||||
} else if (auto addEintOp =
|
||||
llvm::dyn_cast<mlir::zamalang::HLFHE::AddEintOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(addEintOp, operands);
|
||||
} else if (auto subIntEintOp =
|
||||
llvm::dyn_cast<mlir::zamalang::HLFHE::SubIntEintOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(subIntEintOp, operands);
|
||||
} else if (auto mulEintIntOp =
|
||||
llvm::dyn_cast<mlir::zamalang::HLFHE::MulEintIntOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(mulEintIntOp, operands);
|
||||
} else if (llvm::isa<mlir::zamalang::HLFHE::ZeroEintOp>(op) ||
|
||||
llvm::isa<mlir::zamalang::HLFHE::ApplyLookupTableEintOp>(op)) {
|
||||
norm2SqEquiv = llvm::APInt{1, 1, false};
|
||||
} else if (llvm::isa<mlir::ConstantOp>(op)) {
|
||||
isDummy = true;
|
||||
} else if (llvm::isa<mlir::zamalang::HLFHE::HLFHEDialect>(
|
||||
*op->getDialect())) {
|
||||
op->emitError("Unsupported operation");
|
||||
assert(false && "Unsupported operation");
|
||||
} else {
|
||||
isDummy = true;
|
||||
}
|
||||
|
||||
if (!isDummy) {
|
||||
latticeRes.join(MANPLatticeValue{norm2SqEquiv});
|
||||
latticeRes.markOptimisticFixpoint();
|
||||
|
||||
llvm::APInt norm2Equiv = APIntCeilSqrt(norm2SqEquiv);
|
||||
|
||||
op->setAttr("MANP",
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(
|
||||
op->getContext(), norm2Equiv.getBitWidth(),
|
||||
mlir::IntegerType::SignednessSemantics::Unsigned),
|
||||
norm2Equiv));
|
||||
|
||||
if (debug) {
|
||||
op->emitRemark("Squared Minimal Arithmetic Noise Padding: ")
|
||||
<< APIntToStringValUnsigned(norm2SqEquiv) << "\n";
|
||||
}
|
||||
} else {
|
||||
latticeRes.join(MANPLatticeValue{});
|
||||
}
|
||||
|
||||
return mlir::ChangeResult::Change;
|
||||
}
|
||||
|
||||
private:
|
||||
bool debug;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// For documentation see MANP.td
|
||||
struct MANPPass : public MANPBase<MANPPass> {
|
||||
void runOnFunction() override {
|
||||
mlir::FuncOp func = getFunction();
|
||||
|
||||
MANPAnalysis analysis(func->getContext(), debug);
|
||||
analysis.run(func);
|
||||
}
|
||||
MANPPass() = delete;
|
||||
MANPPass(bool debug) : debug(debug){};
|
||||
|
||||
protected:
|
||||
bool debug;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
// Create an instance of the Minimal Arithmetic Noise Padding analysis
|
||||
// pass. If `debug` is true, for each operation, the pass emits a
|
||||
// remark containing the squared Minimal Arithmetic Noise Padding of
|
||||
// the equivalent dot operation.
|
||||
std::unique_ptr<mlir::Pass> createMANPPass(bool debug) {
|
||||
return std::make_unique<MANPPass>(debug);
|
||||
}
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
@@ -1 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Analysis)
|
||||
add_subdirectory(IR)
|
||||
|
||||
@@ -14,6 +14,7 @@ target_link_libraries(zamacompiler
|
||||
LowLFHEDialect
|
||||
MidLFHEDialect
|
||||
HLFHEDialect
|
||||
HLFHEDialectAnalysis
|
||||
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Conversion/Utils/GlobalFHEContext.h"
|
||||
#include "zamalang/Dialect/HLFHE/Analysis/MANP.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
@@ -30,6 +31,7 @@ enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM };
|
||||
|
||||
enum Action {
|
||||
ROUND_TRIP,
|
||||
DEBUG_MANP,
|
||||
DUMP_MIDLFHE,
|
||||
DUMP_LOWLFHE,
|
||||
DUMP_STD,
|
||||
@@ -84,6 +86,9 @@ static llvm::cl::opt<enum Action> action(
|
||||
llvm::cl::values(
|
||||
clEnumValN(Action::ROUND_TRIP, "roundtrip",
|
||||
"Parse input module and regenerate textual representation")),
|
||||
llvm::cl::values(clEnumValN(
|
||||
Action::DEBUG_MANP, "debug-manp",
|
||||
"Minimal Arithmetic Noise Padding for each HLFHE operation")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe",
|
||||
"Lower to MidLFHE and dump result")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe",
|
||||
@@ -253,10 +258,28 @@ mlir::LogicalResult processInputBuffer(
|
||||
// a fallthrough mechanism to the next stage. Actions act as exit
|
||||
// points from the pipeline.
|
||||
switch (entryDialect) {
|
||||
case EntryDialect::HLFHE:
|
||||
case EntryDialect::HLFHE: {
|
||||
bool debugMANP = (action == Action::DEBUG_MANP);
|
||||
|
||||
mlir::LogicalResult manpRes =
|
||||
mlir::zamalang::invokeMANPPass(module, debugMANP);
|
||||
|
||||
if (action == Action::DEBUG_MANP) {
|
||||
if (manpRes.failed()) {
|
||||
mlir::zamalang::log_error()
|
||||
<< "Could not calculate Minimal Arithmetic Noise Padding";
|
||||
|
||||
if (!verifyDiagnostics)
|
||||
return mlir::failure();
|
||||
} else {
|
||||
return mlir::success();
|
||||
}
|
||||
}
|
||||
|
||||
if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose)
|
||||
.failed())
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
// fallthrough
|
||||
case EntryDialect::MIDLFHE:
|
||||
|
||||
Reference in New Issue
Block a user