mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: implement manual manp calculation for FHE.mul_eint
This commit is contained in:
@@ -149,6 +149,10 @@ struct FunctionToDag {
|
||||
// If can't find weights return default leveled op
|
||||
DEBUG("Replace Dot by LevelledOp on " << op);
|
||||
}
|
||||
if (auto mul = asMul(op)) {
|
||||
addMul(dag, mul, encrypted_inputs, precision);
|
||||
return;
|
||||
}
|
||||
if (auto max = asMax(op)) {
|
||||
addMax(dag, max, encrypted_inputs, precision);
|
||||
return;
|
||||
@@ -219,6 +223,70 @@ struct FunctionToDag {
|
||||
manp, slice(out_shape), comment);
|
||||
}
|
||||
|
||||
void addMul(optimizer::Dag &dag, FHE::MulEintOp &mulOp, Inputs &inputs,
|
||||
int precision) {
|
||||
|
||||
// x * y = ((x + y)^2 / 4) - ((x - y)^2 / 4) == tlu(x + y) - tlu(x - y)
|
||||
|
||||
mlir::Value result = mulOp.getResult();
|
||||
const std::vector<uint64_t> resultShape = getShape(result);
|
||||
|
||||
Operation *xOp = mulOp.a().getDefiningOp();
|
||||
Operation *yOp = mulOp.b().getDefiningOp();
|
||||
|
||||
const double fixedCost = NEGLIGIBLE_COMPLEXITY;
|
||||
const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY;
|
||||
|
||||
llvm::APInt xSmanp = llvm::APInt{1, 1, false};
|
||||
if (xOp != nullptr) {
|
||||
const auto xSmanpAttr = xOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
|
||||
assert(xSmanpAttr && "Missing SMANP value on a crypto operation");
|
||||
xSmanp = xSmanpAttr.getValue();
|
||||
}
|
||||
|
||||
llvm::APInt ySmanp = llvm::APInt{1, 1, false};
|
||||
if (yOp != nullptr) {
|
||||
const auto ySmanpAttr = yOp->getAttrOfType<mlir::IntegerAttr>("SMANP");
|
||||
assert(ySmanpAttr && "Missing SMANP value on a crypto operation");
|
||||
ySmanp = ySmanpAttr.getValue();
|
||||
}
|
||||
|
||||
auto loc = loc_to_string(mulOp.getLoc());
|
||||
auto comment = std::string(mulOp->getName().getStringRef()) + " " + loc;
|
||||
|
||||
// (x + y) and (x - y)
|
||||
const double addSubManp =
|
||||
sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble());
|
||||
|
||||
// tlu(v)
|
||||
const double tluManp = 1;
|
||||
|
||||
// tlu(v1) - tlu(v2)
|
||||
const double tluSubManp = sqrt(tluManp + tluManp);
|
||||
|
||||
// for tlus
|
||||
const std::vector<std::uint64_t> unknownFunction;
|
||||
|
||||
// tlu(x + y)
|
||||
auto addNode =
|
||||
dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost,
|
||||
addSubManp, slice(resultShape), comment);
|
||||
auto lhsTluNode = dag->add_lut(addNode, slice(unknownFunction), precision);
|
||||
|
||||
// tlu(x - y)
|
||||
auto subNode =
|
||||
dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost,
|
||||
addSubManp, slice(resultShape), comment);
|
||||
auto rhsTluNode = dag->add_lut(subNode, slice(unknownFunction), precision);
|
||||
|
||||
// tlu(x + y) - tlu(x - y)
|
||||
const std::vector<concrete_optimizer::dag::OperatorIndex> subInputs = {
|
||||
lhsTluNode, rhsTluNode};
|
||||
index[result] =
|
||||
dag->add_levelled_op(slice(subInputs), lweDimCostFactor, fixedCost,
|
||||
tluSubManp, slice(resultShape), comment);
|
||||
}
|
||||
|
||||
void addMax(optimizer::Dag &dag, FHE::MaxEintOp &maxOp, Inputs &inputs,
|
||||
int precision) {
|
||||
mlir::Value result = maxOp.getResult();
|
||||
@@ -346,6 +414,10 @@ struct FunctionToDag {
|
||||
return llvm::dyn_cast<mlir::concretelang::FHELinalg::Dot>(op);
|
||||
}
|
||||
|
||||
mlir::concretelang::FHE::MulEintOp asMul(mlir::Operation &op) {
|
||||
return llvm::dyn_cast<mlir::concretelang::FHE::MulEintOp>(op);
|
||||
}
|
||||
|
||||
mlir::concretelang::FHE::MaxEintOp asMax(mlir::Operation &op) {
|
||||
return llvm::dyn_cast<mlir::concretelang::FHE::MaxEintOp>(op);
|
||||
}
|
||||
|
||||
@@ -487,6 +487,30 @@ static llvm::APInt getSqMANP(
|
||||
return APIntWidthExtendUMul(sqNorm, eNorm);
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of
|
||||
/// `FHE.mul_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::MulEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
assert(operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operands");
|
||||
|
||||
// x * y = ((x + y)^2 / 4) - ((x - y)^2 / 4) == tlu(x + y) - tlu(x - y)
|
||||
|
||||
const llvm::APInt x = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
const llvm::APInt y = operandMANPs[1]->getValue().getMANP().getValue();
|
||||
|
||||
const llvm::APInt beforeTLUs = APIntWidthExtendUAdd(x, y);
|
||||
const llvm::APInt tlu = {1, 1, false};
|
||||
const llvm::APInt result = APIntWidthExtendUAdd(tlu, tlu);
|
||||
|
||||
// this is not optimal as it can increase the resulting noise unnecessarily
|
||||
return APIntUMax(beforeTLUs, result);
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.round` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
@@ -510,8 +534,8 @@ static llvm::APInt getSqMANP(
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
/// that is equivalent to an `FHE.max_eint` operation.
|
||||
/// Calculates the squared Minimal Arithmetic Noise Padding of
|
||||
/// `FHE.max_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::MaxEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
@@ -1245,6 +1269,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
} else if (auto mulEintIntOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::MulEintIntOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(mulEintIntOp, operands);
|
||||
} else if (auto mulEintOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::MulEintOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(mulEintOp, operands);
|
||||
} else if (auto roundOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::RoundEintOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(roundOp, operands);
|
||||
|
||||
Reference in New Issue
Block a user