feat: implement manual manp calculation for FHE.mul_eint

This commit is contained in:
Umut
2023-02-21 11:58:15 +01:00
parent bc69c87d62
commit d8eafabd22
2 changed files with 101 additions and 2 deletions

View File

@@ -149,6 +149,10 @@ struct FunctionToDag {
// If can't find weights return default leveled op
DEBUG("Replace Dot by LevelledOp on " << op);
}
if (auto mul = asMul(op)) {
addMul(dag, mul, encrypted_inputs, precision);
return;
}
if (auto max = asMax(op)) {
addMax(dag, max, encrypted_inputs, precision);
return;
@@ -219,6 +223,70 @@ struct FunctionToDag {
manp, slice(out_shape), comment);
}
void addMul(optimizer::Dag &dag, FHE::MulEintOp &mulOp, Inputs &inputs,
int precision) {
// x * y = ((x + y)^2 / 4) - ((x - y)^2 / 4) == tlu(x + y) - tlu(x - y)
mlir::Value result = mulOp.getResult();
const std::vector<uint64_t> resultShape = getShape(result);
Operation *xOp = mulOp.a().getDefiningOp();
Operation *yOp = mulOp.b().getDefiningOp();
const double fixedCost = NEGLIGIBLE_COMPLEXITY;
const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;
llvm::APInt xSmanp = llvm::APInt{1, 1, false};
if (xOp != nullptr) {
const auto xSmanpAttr = xOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
assert(xSmanpAttr && "Missing SMANP value on a crypto operation");
xSmanp = xSmanpAttr.getValue();
}
llvm::APInt ySmanp = llvm::APInt{1, 1, false};
if (yOp != nullptr) {
const auto ySmanpAttr = yOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
assert(ySmanpAttr && "Missing SMANP value on a crypto operation");
ySmanp = ySmanpAttr.getValue();
}
auto loc = loc_to_string(mulOp.getLoc());
auto comment = std::string(mulOp->getName().getStringRef()) + " " + loc;
// (x + y) and (x - y)
const double addSubManp =
sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble());
// tlu(v)
const double tluManp = 1;
// tlu(v1) - tlu(v2)
const double tluSubManp = sqrt(tluManp + tluManp);
// for tlus
const std::vector<std::uint64_t> unknownFunction;
// tlu(x + y)
auto addNode =
dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost,
addSubManp, slice(resultShape), comment);
auto lhsTluNode = dag->add_lut(addNode, slice(unknownFunction), precision);
// tlu(x - y)
auto subNode =
dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost,
addSubManp, slice(resultShape), comment);
auto rhsTluNode = dag->add_lut(subNode, slice(unknownFunction), precision);
// tlu(x + y) - tlu(x - y)
const std::vector<concrete_optimizer::dag::OperatorIndex> subInputs = {
lhsTluNode, rhsTluNode};
index[result] =
dag->add_levelled_op(slice(subInputs), lweDimCostFactor, fixedCost,
tluSubManp, slice(resultShape), comment);
}
void addMax(optimizer::Dag &dag, FHE::MaxEintOp &maxOp, Inputs &inputs,
int precision) {
mlir::Value result = maxOp.getResult();
@@ -346,6 +414,10 @@ struct FunctionToDag {
return llvm::dyn_cast<mlir::concretelang::FHELinalg::Dot>(op);
}
mlir::concretelang::FHE::MulEintOp asMul(mlir::Operation &op) {
return llvm::dyn_cast<mlir::concretelang::FHE::MulEintOp>(op);
}
mlir::concretelang::FHE::MaxEintOp asMax(mlir::Operation &op) {
return llvm::dyn_cast<mlir::concretelang::FHE::MaxEintOp>(op);
}

View File

@@ -487,6 +487,30 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUMul(sqNorm, eNorm);
}
/// Calculates the squared Minimal Arithmetic Noise Padding of
/// `FHE.mul_eint` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::MulEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
assert(operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
operandMANPs[1]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
"operands");
// x * y = ((x + y)^2 / 4) - ((x - y)^2 / 4) == tlu(x + y) - tlu(x - y)
const llvm::APInt x = operandMANPs[0]->getValue().getMANP().getValue();
const llvm::APInt y = operandMANPs[1]->getValue().getMANP().getValue();
const llvm::APInt beforeTLUs = APIntWidthExtendUAdd(x, y);
const llvm::APInt tlu = {1, 1, false};
const llvm::APInt result = APIntWidthExtendUAdd(tlu, tlu);
// this is not optimal as it can increase the resulting noise unnecessarily
return APIntUMax(beforeTLUs, result);
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.round` operation.
static llvm::APInt getSqMANP(
@@ -510,8 +534,8 @@ static llvm::APInt getSqMANP(
return eNorm;
}
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
/// that is equivalent to an `FHE.max_eint` operation.
/// Calculates the squared Minimal Arithmetic Noise Padding of
/// `FHE.max_eint` operation.
static llvm::APInt getSqMANP(
mlir::concretelang::FHE::MaxEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -1245,6 +1269,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
} else if (auto mulEintIntOp =
llvm::dyn_cast<mlir::concretelang::FHE::MulEintIntOp>(op)) {
norm2SqEquiv = getSqMANP(mulEintIntOp, operands);
} else if (auto mulEintOp =
llvm::dyn_cast<mlir::concretelang::FHE::MulEintOp>(op)) {
norm2SqEquiv = getSqMANP(mulEintOp, operands);
} else if (auto roundOp =
llvm::dyn_cast<mlir::concretelang::FHE::RoundEintOp>(op)) {
norm2SqEquiv = getSqMANP(roundOp, operands);