From e1fb417c54c13049e694c4c87c19c9f02389e126 Mon Sep 17 00:00:00 2001 From: rudy Date: Mon, 22 Aug 2022 13:09:14 +0200 Subject: [PATCH] fix(optimizer): no ceiling for MANP value given to the optimizer --- .../lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp | 11 +++++------ compiler/lib/Dialect/FHE/Analysis/MANP.cpp | 7 +++++++ .../tests/check_tests/Dialect/FHE/Analysis/MANP.mlir | 2 +- .../check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir | 2 +- .../check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir | 2 +- 5 files changed, 15 insertions(+), 9 deletions(-) diff --git a/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp b/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp index 98c638e32..96b3ba6c3 100644 --- a/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp @@ -4,6 +4,7 @@ // for license information. #include +#include #include #include @@ -75,6 +76,7 @@ struct FunctionToDag { for (auto &bb : func.getBody().getBlocks()) { for (auto &op : bb.getOperations()) { addOperation(dag, op); + op.removeAttr("SMANP"); } } if (index.empty()) { @@ -178,13 +180,10 @@ struct FunctionToDag { // Default complexity is negligible double fixed_cost = NEGLIGIBLE_COMPLEXITY; double lwe_dim_cost_factor = NEGLIGIBLE_COMPLEXITY; - auto manp_int = op.getAttrOfType("MANP"); + auto smanp_int = op.getAttrOfType("SMANP"); auto loc = loc_to_string(op.getLoc()); - if (!manp_int) { - DEBUG("Cannot read manp on " << op << "\n" << loc); - } - assert(manp_int && "Missing manp value on a crypto operation"); - double manp = (double)manp_int.getValue().getZExtValue(); + assert(smanp_int && "Missing manp value on a crypto operation"); + double manp = sqrt((double)smanp_int.getValue().getZExtValue()); auto comment = std::string(op.getName().getStringRef()) + " " + loc; index[val] = dag->add_levelled_op(slice(inputs), lwe_dim_cost_factor, fixed_cost, diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 521a842f1..513a13a3e 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -1470,6 +1470,13 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { latticeRes.join(MANPLatticeValue{norm2SqEquiv}); latticeRes.markOptimisticFixpoint(); + op->setAttr("SMANP", + mlir::IntegerAttr::get( + mlir::IntegerType::get( + op->getContext(), norm2SqEquiv.getBitWidth(), + mlir::IntegerType::SignednessSemantics::Unsigned), + norm2SqEquiv)); + llvm::APInt norm2Equiv = APIntCeilSqrt(norm2SqEquiv); op->setAttr("MANP", diff --git a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir index b42876a84..30a14e4e2 100644 --- a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes MANP --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s +// RUN: concretecompiler --passes MANP --passes ConcreteOptimizer --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s func.func @single_zero() -> !FHE.eint<2> { diff --git a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir index 150602d06..a1a2cbcd7 100644 --- a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes canonicalize --passes MANP --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s +// RUN: concretecompiler --passes canonicalize --passes MANP --passes ConcreteOptimizer --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s func.func @single_cst_add_eint_int(%t: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> { diff --git a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir index be65deb0d..25bc57f7c 100644 --- a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes MANP --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s +// RUN: concretecompiler --passes MANP --passes ConcreteOptimizer --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s func.func @tensor_from_elements_1(%a: !FHE.eint<2>, %b: !FHE.eint<2>, %c: !FHE.eint<2>, %d: !FHE.eint<2>) -> tensor<4x!FHE.eint<2>> {