diff --git a/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp b/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp index 7a02bad8b..1725dde1b 100644 --- a/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp @@ -149,6 +149,10 @@ struct FunctionToDag { // If can't find weights return default leveled op DEBUG("Replace Dot by LevelledOp on " << op); } + if (auto mul = asMul(op)) { + addMul(dag, mul, encrypted_inputs, precision); + return; + } if (auto max = asMax(op)) { addMax(dag, max, encrypted_inputs, precision); return; @@ -219,6 +223,70 @@ struct FunctionToDag { manp, slice(out_shape), comment); } + void addMul(optimizer::Dag &dag, FHE::MulEintOp &mulOp, Inputs &inputs, + int precision) { + + // x * y = ((x + y)^2 / 4) - ((x - y)^2 / 4) == tlu(x + y) - tlu(x - y) + + mlir::Value result = mulOp.getResult(); + const std::vector resultShape = getShape(result); + + Operation *xOp = mulOp.a().getDefiningOp(); + Operation *yOp = mulOp.b().getDefiningOp(); + + const double fixedCost = NEGLIGIBLE_COMPLEXITY; + const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY; + + llvm::APInt xSmanp = llvm::APInt{1, 1, false}; + if (xOp != nullptr) { + const auto xSmanpAttr = xOp->getAttrOfType("SMANP"); + assert(xSmanpAttr && "Missing SMANP value on a crypto operation"); + xSmanp = xSmanpAttr.getValue(); + } + + llvm::APInt ySmanp = llvm::APInt{1, 1, false}; + if (yOp != nullptr) { + const auto ySmanpAttr = yOp->getAttrOfType("SMANP"); + assert(ySmanpAttr && "Missing SMANP value on a crypto operation"); + ySmanp = ySmanpAttr.getValue(); + } + + auto loc = loc_to_string(mulOp.getLoc()); + auto comment = std::string(mulOp->getName().getStringRef()) + " " + loc; + + // (x + y) and (x - y) + const double addSubManp = + sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble()); + + // tlu(v) + const double tluManp = 1; + + // tlu(v1) - tlu(v2) + const double tluSubManp = sqrt(tluManp + tluManp); + + // for tlus + const std::vector unknownFunction; + + // tlu(x + y) + auto addNode = + dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, + addSubManp, slice(resultShape), comment); + auto lhsTluNode = dag->add_lut(addNode, slice(unknownFunction), precision); + + // tlu(x - y) + auto subNode = + dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, + addSubManp, slice(resultShape), comment); + auto rhsTluNode = dag->add_lut(subNode, slice(unknownFunction), precision); + + // tlu(x + y) - tlu(x - y) + const std::vector subInputs = { + lhsTluNode, rhsTluNode}; + index[result] = + dag->add_levelled_op(slice(subInputs), lweDimCostFactor, fixedCost, + tluSubManp, slice(resultShape), comment); + } + void addMax(optimizer::Dag &dag, FHE::MaxEintOp &maxOp, Inputs &inputs, int precision) { mlir::Value result = maxOp.getResult(); @@ -346,6 +414,10 @@ struct FunctionToDag { return llvm::dyn_cast(op); } + mlir::concretelang::FHE::MulEintOp asMul(mlir::Operation &op) { + return llvm::dyn_cast(op); + } + mlir::concretelang::FHE::MaxEintOp asMax(mlir::Operation &op) { return llvm::dyn_cast(op); } diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index d69de6079..0f7ac4a48 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -487,6 +487,30 @@ static llvm::APInt getSqMANP( return APIntWidthExtendUMul(sqNorm, eNorm); } +/// Calculates the squared Minimal Arithmetic Noise Padding of +/// `FHE.mul_eint` operation. +static llvm::APInt getSqMANP( + mlir::concretelang::FHE::MulEintOp op, + llvm::ArrayRef *> operandMANPs) { + assert(operandMANPs.size() == 2 && + operandMANPs[0]->getValue().getMANP().hasValue() && + operandMANPs[1]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted " + "operands"); + + // x * y = ((x + y)^2 / 4) - ((x - y)^2 / 4) == tlu(x + y) - tlu(x - y) + + const llvm::APInt x = operandMANPs[0]->getValue().getMANP().getValue(); + const llvm::APInt y = operandMANPs[1]->getValue().getMANP().getValue(); + + const llvm::APInt beforeTLUs = APIntWidthExtendUAdd(x, y); + const llvm::APInt tlu = {1, 1, false}; + const llvm::APInt result = APIntWidthExtendUAdd(tlu, tlu); + + // this is not optimal as it can increase the resulting noise unnecessarily + return APIntUMax(beforeTLUs, result); +} + /// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation /// that is equivalent to an `FHE.round` operation. static llvm::APInt getSqMANP( @@ -510,8 +534,8 @@ static llvm::APInt getSqMANP( return eNorm; } -/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation -/// that is equivalent to an `FHE.max_eint` operation. +/// Calculates the squared Minimal Arithmetic Noise Padding of +/// `FHE.max_eint` operation. static llvm::APInt getSqMANP( mlir::concretelang::FHE::MaxEintOp op, llvm::ArrayRef *> operandMANPs) { @@ -1245,6 +1269,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { } else if (auto mulEintIntOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(mulEintIntOp, operands); + } else if (auto mulEintOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(mulEintOp, operands); } else if (auto roundOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(roundOp, operands);