mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix: Do not assert fail with too large weight and fix computation of 2-Norm with negative weigth (close #892)
This commit is contained in:
@@ -16,7 +16,7 @@ unsigned int getEintPrecision(mlir::Value value);
|
||||
std::unique_ptr<mlir::Pass> createMANPPass(bool debug = false);
|
||||
|
||||
std::unique_ptr<mlir::Pass>
|
||||
createMaxMANPPass(std::function<void(const llvm::APInt &, unsigned)> setMax);
|
||||
createMaxMANPPass(std::function<void(uint64_t, unsigned)> setMax);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -136,6 +136,7 @@ struct FunctionToDag {
|
||||
addDot(dag, val, encrypted_inputs, weightsOpt.getValue());
|
||||
return;
|
||||
}
|
||||
// If can't find weights return default leveled op
|
||||
DEBUG("Replace Dot by LevelledOp on " << op);
|
||||
}
|
||||
// default
|
||||
@@ -229,13 +230,16 @@ struct FunctionToDag {
|
||||
return value.isa<mlir::BlockArgument>();
|
||||
}
|
||||
|
||||
std::vector<std::int64_t>
|
||||
llvm::Optional<std::vector<std::int64_t>>
|
||||
resolveConstantVectorWeights(mlir::arith::ConstantOp &cstOp) {
|
||||
std::vector<std::int64_t> values;
|
||||
mlir::DenseIntElementsAttr denseVals =
|
||||
cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value");
|
||||
|
||||
for (llvm::APInt val : denseVals.getValues<llvm::APInt>()) {
|
||||
if (val.getActiveBits() > 64) {
|
||||
return llvm::None;
|
||||
}
|
||||
values.push_back(val.getSExtValue());
|
||||
}
|
||||
return values;
|
||||
|
||||
@@ -168,12 +168,8 @@ static llvm::APInt APIntWidthExtendUnsignedSq(const llvm::APInt &i) {
|
||||
|
||||
/// Calculates the square of the value of `i`.
|
||||
static llvm::APInt APIntWidthExtendSqForConstant(const llvm::APInt &i) {
|
||||
// Make sure the required number of bits can be represented by the
|
||||
// `unsigned` argument of `zext`.
|
||||
assert(i.getActiveBits() < 32 &&
|
||||
"Square of the constant cannot be represented on 64 bits");
|
||||
return llvm::APInt(2 * i.getActiveBits(),
|
||||
i.getZExtValue() * i.getZExtValue());
|
||||
llvm::APInt extI(2 * i.getActiveBits(), i.getSExtValue());
|
||||
return extI * extI;
|
||||
}
|
||||
|
||||
/// Calculates the square root of `i` and rounds it to the next highest
|
||||
@@ -1394,14 +1390,11 @@ struct MaxMANPPass : public MaxMANPBase<MaxMANPPass> {
|
||||
[&](mlir::Operation *childOp) { this->processOperation(childOp); });
|
||||
}
|
||||
MaxMANPPass() = delete;
|
||||
MaxMANPPass(std::function<void(const llvm::APInt &, unsigned)> updateMax)
|
||||
: updateMax(updateMax), maxMANP(llvm::APInt{1, 0, false}),
|
||||
maxEintWidth(0){};
|
||||
MaxMANPPass(std::function<void(const uint64_t, unsigned)> updateMax)
|
||||
: updateMax(updateMax){};
|
||||
|
||||
protected:
|
||||
void processOperation(mlir::Operation *op) {
|
||||
static const llvm::APInt one{1, 1, false};
|
||||
bool upd = false;
|
||||
|
||||
// Process all function arguments and use the default value of 1
|
||||
// for MANP and the declarend precision
|
||||
@@ -1410,15 +1403,7 @@ protected:
|
||||
for (mlir::BlockArgument blockArg : func.getBody().getArguments()) {
|
||||
if (isEncryptedFunctionParameter(blockArg)) {
|
||||
unsigned int width = fhe::utils::getEintPrecision(blockArg);
|
||||
|
||||
if (this->maxEintWidth < width) {
|
||||
this->maxEintWidth = width;
|
||||
}
|
||||
|
||||
if (APIntWidthExtendULT(this->maxMANP, one)) {
|
||||
this->maxMANP = one;
|
||||
upd = true;
|
||||
}
|
||||
this->updateMax(1, width);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1439,38 +1424,31 @@ protected:
|
||||
}
|
||||
|
||||
if (eTy) {
|
||||
if (this->maxEintWidth < eTy.getWidth()) {
|
||||
this->maxEintWidth = eTy.getWidth();
|
||||
upd = true;
|
||||
}
|
||||
|
||||
mlir::IntegerAttr MANP = op->getAttrOfType<mlir::IntegerAttr>("MANP");
|
||||
|
||||
if (!MANP) {
|
||||
op->emitError("Maximum Arithmetic Noise Padding value not set");
|
||||
op->emitError("2-Norm has not been computed");
|
||||
this->signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (APIntWidthExtendULT(this->maxMANP, MANP.getValue())) {
|
||||
this->maxMANP = MANP.getValue();
|
||||
upd = true;
|
||||
auto manp = MANP.getValue();
|
||||
if (!manp.isIntN(64)) {
|
||||
op->emitError("2-Norm cannot be reprensented on 64bits");
|
||||
this->signalPassFailure();
|
||||
return;
|
||||
}
|
||||
this->updateMax(manp.getSExtValue(), eTy.getWidth());
|
||||
}
|
||||
}
|
||||
|
||||
if (upd)
|
||||
this->updateMax(this->maxMANP, this->maxEintWidth);
|
||||
}
|
||||
|
||||
std::function<void(const llvm::APInt &, unsigned)> updateMax;
|
||||
llvm::APInt maxMANP;
|
||||
unsigned int maxEintWidth;
|
||||
std::function<void(const uint64_t, unsigned)> updateMax;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
std::unique_ptr<mlir::Pass> createMaxMANPPass(
|
||||
std::function<void(const llvm::APInt &, unsigned)> updateMax) {
|
||||
std::unique_ptr<mlir::Pass>
|
||||
createMaxMANPPass(std::function<void(const uint64_t, unsigned)> updateMax) {
|
||||
return std::make_unique<MaxMANPPass>(updateMax);
|
||||
}
|
||||
|
||||
|
||||
@@ -94,24 +94,14 @@ getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm,
|
||||
mlir::concretelang::createMaxMANPPass([&](const llvm::APInt &currMaxMANP,
|
||||
unsigned currMaxWidth) {
|
||||
assert((uint64_t)currMaxWidth < std::numeric_limits<size_t>::max() &&
|
||||
"Maximum width does not fit into size_t");
|
||||
mlir::concretelang::createMaxMANPPass(
|
||||
[&](const uint64_t manp, unsigned width) {
|
||||
if (!oMax2norm.hasValue() || oMax2norm.getValue() < manp)
|
||||
oMax2norm.emplace(manp);
|
||||
|
||||
assert(sizeof(uint64_t) >= sizeof(size_t) &&
|
||||
currMaxMANP.ult(std::numeric_limits<size_t>::max()) &&
|
||||
"Maximum MANP does not fit into size_t");
|
||||
|
||||
size_t manp = (size_t)currMaxMANP.getZExtValue();
|
||||
size_t width = (size_t)currMaxWidth;
|
||||
|
||||
if (!oMax2norm.hasValue() || oMax2norm.getValue() < manp)
|
||||
oMax2norm.emplace(manp);
|
||||
|
||||
if (!oMaxWidth.hasValue() || oMaxWidth.getValue() < width)
|
||||
oMaxWidth.emplace(width);
|
||||
}),
|
||||
if (!oMaxWidth.hasValue() || oMaxWidth.getValue() < width)
|
||||
oMaxWidth.emplace(width);
|
||||
}),
|
||||
enablePass);
|
||||
if (pm.run(module.getOperation()).failed()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
|
||||
@@ -183,6 +183,17 @@ func.func @single_cst_mul_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
return %0 : !FHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @single_cst_mul_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
{
|
||||
%cst = arith.constant -1 : i3
|
||||
|
||||
// %0 = "FHE.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
%0 = "FHE.mul_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
|
||||
return %0 : !FHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
@@ -270,7 +270,7 @@ func.func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>)
|
||||
|
||||
%cst = arith.constant dense<[1, 2, 3, -1]> : tensor<4xi3>
|
||||
// sqrt(1^2*9 + 2^2*9 + 3^2*9 + 1^2*9) = sqrt(135) = 12
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 56 : ui{{[[0-9]+}}}
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 28 : ui{{[[0-9]+}}}
|
||||
%1 = "FHELinalg.dot_eint_int"(%0, %cst) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2>
|
||||
|
||||
return %1 : !FHE.eint<2>
|
||||
|
||||
Reference in New Issue
Block a user